[Model] Add multi-image support for minicpmv (#7122)

Co-authored-by: hezhihui <hzh7269@modelbest.cn>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Alphi 2024-08-05 09:23:17 +08:00 committed by GitHub
parent f80ab3521c
commit 7b86e7c9cd
4 changed files with 172 additions and 37 deletions
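
For orientation only (this sketch is not part of the commit): the change lets a single prompt reference several images, with multi_modal_data["image"] carrying a list of PIL images, one per "(<image>./</image>)" tag. A rough offline-API sketch follows; the model name and prompt template are taken from the new test below, while the image file names and the exact generate() call shape are assumptions about vLLM's prompt-dict API of this period.

# Rough usage sketch -- NOT part of this commit. The image file names are
# placeholders; the prompt-dict form of LLM.generate() is assumed.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="openbmb/MiniCPM-Llama3-V-2_5",
          trust_remote_code=True,
          max_model_len=4096)

# One "(<image>./</image>)" tag per image, matching HF_MULTIIMAGE_IMAGE_PROMPT
# in the test added below.
prompt = ("<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
          "(<image>./</image>)\n(<image>./</image>)\n"
          "Describe these images.<|eot_id|>"
          "<|start_header_id|>assistant<|end_header_id|>\n\n")

images = [Image.open("image_a.jpg"), Image.open("image_b.jpg")]

outputs = llm.generate(
    {
        "prompt": prompt,
        # With this change, "image" may be a list of PIL images.
        "multi_modal_data": {"image": images},
    },
    SamplingParams(temperature=0.0, max_tokens=128),
)
# The new test additionally passes stop_token_ids=[tokenizer.eos_id,
# tokenizer.eot_id]; that is omitted here for brevity.
print(outputs[0].outputs[0].text)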

View File

@@ -3,7 +3,7 @@ import gc
import os
import sys
from collections import UserList
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
import pytest
import torch
@@ -508,7 +508,8 @@ class VllmRunner:
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[List[Image.Image]] = None,
        images: Optional[Union[List[Image.Image],
                               List[List[Image.Image]]]] = None,
        stop_token_ids: Optional[List[int]] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        greedy_logprobs_params = SamplingParams(temperature=0.0,

View File

@@ -14,6 +14,18 @@ from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm


class NestedInputs(UserDict):

    def __init__(self, model_inputs: BatchFeature):
        super().__init__({"model_inputs": model_inputs})
        self.model_inputs = model_inputs

    def to(self, device: torch.types.Device):
        return NestedInputs(self.model_inputs.to(device))


# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
@@ -23,7 +35,7 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "cherry_blossom":
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
    "(<image>./</image>)\nWhat is the season?<|eot_id|>" \
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
    "<|start_header_id|>assistant<|end_header_id|>\n\n",
})
models = ["openbmb/MiniCPM-Llama3-V-2_5"]
@@ -94,22 +106,10 @@ def run_test(
    ]

    with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
        class NestedInputs(UserDict):

            def __init__(self, model_inputs: BatchFeature):
                super().__init__({"model_inputs": model_inputs})
                self.model_inputs = model_inputs

            def to(self, device: torch.types.Device):
                return NestedInputs(self.model_inputs.to(device))

        hf_processor = hf_model.processor
        hf_model.processor = lambda **kw: NestedInputs(
            hf_processor(**kw)  # type: ignore
        )

        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
@@ -161,3 +161,123 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
HF_MULTIIMAGE_IMAGE_PROMPT = \
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
    "(<image>./</image>)\n(<image>./</image>)\n" \
    "Describe these images.<|eot_id|>" \
    "<|start_header_id|>assistant<|end_header_id|>\n\n"


def run_multi_image_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test are under tests/images.
    For the huggingface runner, we provide the PIL images as input.
    For the vllm runner, we provide MultiModalDataDict objects
    and the corresponding vision language config as input.
    Note that the text input is also adjusted to abide by the vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
    images = [asset.pil_image for asset in image_assets]

    inputs_per_case = [
        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
         [[rescale_image_size(image, factor) for image in images]
          for factor in size_factors])
    ]

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     max_model_len=4096,
                     max_num_seqs=1,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
        tokenizer = vllm_model.model.get_tokenizer()
        stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
        vllm_outputs_per_case = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=images,
                                                stop_token_ids=stop_token_ids)
            for prompts, images in inputs_per_case
        ]

    with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
        hf_processor = hf_model.processor
        hf_model.processor = lambda **kw: NestedInputs(
            hf_processor(**kw)  # type: ignore
        )

        hf_outputs_per_case = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=images,
                                                    tokenizer=tokenizer)
            for prompts, images in inputs_per_case
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
                                        vllm_outputs_per_case):
        check_logprobs_close(
            outputs_0_lst=[
                trunc_hf_output(hf_output) for hf_output in hf_outputs
            ],
            outputs_1_lst=vllm_outputs,
            name_0="hf",
            name_1="vllm",
        )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
                             size_factors, dtype: str, max_tokens: int,
                             num_logprobs: int) -> None:
    run_multi_image_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )

View File

@@ -392,6 +392,20 @@ class Resampler2_5(BaseResampler):
        return x


def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
    version_float = getattr(config, "version", None)

    # The old configs do not include version number
    # TODO: Remove this after the HF repos are updated
    if version_float is None:
        if config.hidden_size == 2304 and config.query_num == 64:
            return (2, 0)
        return (2, 5)

    version_str = str(version_float)
    return tuple(int(x) for x in version_str.split("."))


def get_max_minicpmv_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(PretrainedConfig)
    return getattr(hf_config, "query_num", 64)
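
get_version_by_config centralizes the version detection that was previously inlined in MiniCPMVBaseModel.__init__ (see the last hunk of this file). A small illustration of its behavior follows, using duck-typed stand-in objects rather than real transformers PretrainedConfig instances; the import path is an assumption about where this hunk lives in the vLLM tree.

# Illustration only: stand-in configs, not real PretrainedConfig objects.
from types import SimpleNamespace

# Assumed import path for the helper added in this hunk.
from vllm.model_executor.models.minicpmv import get_version_by_config

# Legacy config without a "version" field that matches the 2.0 shape.
assert get_version_by_config(SimpleNamespace(hidden_size=2304,
                                             query_num=64)) == (2, 0)

# Legacy config without "version" that does not match that shape -> 2.5.
assert get_version_by_config(SimpleNamespace(hidden_size=4096,
                                             query_num=96)) == (2, 5)

# Configs that do carry a version field are parsed numerically.
assert get_version_by_config(SimpleNamespace(version=2.5)) == (2, 5)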
@@ -421,36 +435,43 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    version = get_version_by_config(model_config.hf_config)
    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)
    image_processor = cached_get_image_processor(model_config.tokenizer)

    def get_placeholder(image_size: Tuple[int, int], num_image: int):
        if version == (2, 0) or version == (2, 5):
            return image_processor. \
                get_slice_image_placeholder(image_size)
        return image_processor. \
            get_slice_image_placeholder(image_size, num_image)

    prompt = llm_inputs.get("prompt")
    if prompt is None:
        token_ids = llm_inputs.get("prompt_token_ids")
        prompt = tokenizer.decode(token_ids)

    image_processor = cached_get_image_processor(model_config.tokenizer)

    pattern = "(<image>./</image>)"
    image = multi_modal_data["image"]
    images = multi_modal_data["image"]
    if isinstance(images, Image.Image):
        images = [images]
    image_tags = re.findall(pattern, prompt)

    if len(image_tags) == 0:
        new_token_ids = token_ids
        new_prompt = prompt
    else:
        if len(image_tags) > 1:
            logger.warning("Multiple image input is not supported yet, "
                           "so any extra image tokens will be treated "
                           "as plain text.")

        text_chunks = prompt.split(pattern)
        new_prompt = (text_chunks[0] +
                      image_processor.get_slice_image_placeholder(image.size) +
                      "".join(text_chunks[1:]))
        new_prompt_chunks: List[str] = []
        for i in range(len(images)):
            new_prompt_chunks += [
                text_chunks[i],
                get_placeholder(images[i].size, i)
            ]
        new_prompt_chunks.append(text_chunks[-1])
        new_prompt = "".join(new_prompt_chunks)

        new_token_ids = tokenizer.encode(new_prompt)

    llm_inputs = LLMInputs(
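
Taken on its own, the rewritten expansion splits the prompt on the image tag and interleaves one per-image placeholder per chunk, in prompt order. Below is a self-contained sketch of that interleaving; the dummy get_placeholder stands in for the model's get_slice_image_placeholder, whose real output is the model-specific image/slice token layout.

# Self-contained sketch of the placeholder interleaving above; get_placeholder
# here is a dummy standing in for the HF image processor's
# get_slice_image_placeholder.
from typing import List, Tuple

pattern = "(<image>./</image>)"
prompt = ("<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
          "(<image>./</image>)\n(<image>./</image>)\n"
          "Describe these images.<|eot_id|>")


def get_placeholder(image_size: Tuple[int, int], num_image: int) -> str:
    return f"<placeholder_{num_image}:{image_size[0]}x{image_size[1]}>"


image_sizes = [(1280, 720), (640, 480)]  # sizes of the two input images

text_chunks = prompt.split(pattern)  # 3 chunks for 2 tags
new_prompt_chunks: List[str] = []
for i, size in enumerate(image_sizes):
    new_prompt_chunks += [text_chunks[i], get_placeholder(size, i)]
new_prompt_chunks.append(text_chunks[-1])
new_prompt = "".join(new_prompt_chunks)  # one placeholder per image, in order
print(new_prompt)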
@@ -478,14 +499,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsVision):
        self.config = config
        self.multimodal_config = multimodal_config

        if not hasattr(self.config, "version"):
            if self.config.hidden_size == 2304 and self.config.query_num == 64:
                self.version = (2, 0)
            else:
                self.version = (2, 5)
        else:
            self.version = str(self.config.version).split(".")
            self.version = tuple([int(x) for x in self.version])
        self.version = get_version_by_config(self.config)
        self.llm = self.init_llm(config, cache_config, quant_config)
        self.vpm = self.init_vision_module()
        param_dtype = torch.get_default_dtype()

View File

@@ -113,7 +113,7 @@ class ImagePlugin(MultiModalPlugin):
    def _default_input_mapper(self, ctx: InputContext,
                              data: object) -> MultiModalInputs:
        model_config = ctx.model_config
        if isinstance(data, Image.Image):
        if isinstance(data, (Image.Image, list)):
            image_processor = self._get_hf_image_processor(model_config)
            if image_processor is None:
                raise RuntimeError("No HuggingFace processor is available "