[Model] Add multi-image support for minicpmv (#7122)
Co-authored-by: hezhihui <hzh7269@modelbest.cn> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
parent
f80ab3521c
commit
7b86e7c9cd
@ -3,7 +3,7 @@ import gc
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from collections import UserList
|
from collections import UserList
|
||||||
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar
|
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@ -508,7 +508,8 @@ class VllmRunner:
|
|||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
images: Optional[List[Image.Image]] = None,
|
images: Optional[Union[List[Image.Image],
|
||||||
|
List[List[Image.Image]]]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = None,
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
|
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
|
||||||
greedy_logprobs_params = SamplingParams(temperature=0.0,
|
greedy_logprobs_params = SamplingParams(temperature=0.0,
|
||||||
|
@ -14,6 +14,18 @@ from .utils import check_logprobs_close
|
|||||||
|
|
||||||
pytestmark = pytest.mark.vlm
|
pytestmark = pytest.mark.vlm
|
||||||
|
|
||||||
|
|
||||||
|
class NestedInputs(UserDict):
|
||||||
|
|
||||||
|
def __init__(self, model_inputs: BatchFeature):
|
||||||
|
super().__init__({"model_inputs": model_inputs})
|
||||||
|
|
||||||
|
self.model_inputs = model_inputs
|
||||||
|
|
||||||
|
def to(self, device: torch.types.Device):
|
||||||
|
return NestedInputs(self.model_inputs.to(device))
|
||||||
|
|
||||||
|
|
||||||
# The image token is placed before "user" on purpose so that the test can pass
|
# The image token is placed before "user" on purpose so that the test can pass
|
||||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||||
"stop_sign":
|
"stop_sign":
|
||||||
@ -23,7 +35,7 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
|||||||
"cherry_blossom":
|
"cherry_blossom":
|
||||||
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
|
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
|
||||||
"(<image>./</image>)\nWhat is the season?<|eot_id|>" \
|
"(<image>./</image>)\nWhat is the season?<|eot_id|>" \
|
||||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
"<|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||||
})
|
})
|
||||||
|
|
||||||
models = ["openbmb/MiniCPM-Llama3-V-2_5"]
|
models = ["openbmb/MiniCPM-Llama3-V-2_5"]
|
||||||
@ -94,22 +106,10 @@ def run_test(
|
|||||||
]
|
]
|
||||||
|
|
||||||
with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
|
with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
|
||||||
|
|
||||||
class NestedInputs(UserDict):
|
|
||||||
|
|
||||||
def __init__(self, model_inputs: BatchFeature):
|
|
||||||
super().__init__({"model_inputs": model_inputs})
|
|
||||||
|
|
||||||
self.model_inputs = model_inputs
|
|
||||||
|
|
||||||
def to(self, device: torch.types.Device):
|
|
||||||
return NestedInputs(self.model_inputs.to(device))
|
|
||||||
|
|
||||||
hf_processor = hf_model.processor
|
hf_processor = hf_model.processor
|
||||||
hf_model.processor = lambda **kw: NestedInputs(
|
hf_model.processor = lambda **kw: NestedInputs(
|
||||||
hf_processor(**kw) # type: ignore
|
hf_processor(**kw) # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
hf_outputs_per_image = [
|
hf_outputs_per_image = [
|
||||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||||
max_tokens,
|
max_tokens,
|
||||||
@ -161,3 +161,123 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
|
|||||||
num_logprobs=num_logprobs,
|
num_logprobs=num_logprobs,
|
||||||
tensor_parallel_size=1,
|
tensor_parallel_size=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
HF_MULTIIMAGE_IMAGE_PROMPT = \
|
||||||
|
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
|
||||||
|
"(<image>./</image>)\n(<image>./</image>)\n" \
|
||||||
|
"Describe these images.<|eot_id|>" \
|
||||||
|
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
def run_multi_image_test(
|
||||||
|
hf_runner: Type[HfRunner],
|
||||||
|
vllm_runner: Type[VllmRunner],
|
||||||
|
image_assets: _ImageAssets,
|
||||||
|
model: str,
|
||||||
|
*,
|
||||||
|
size_factors: List[float],
|
||||||
|
dtype: str,
|
||||||
|
max_tokens: int,
|
||||||
|
num_logprobs: int,
|
||||||
|
tensor_parallel_size: int,
|
||||||
|
distributed_executor_backend: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""Inference result should be the same between hf and vllm.
|
||||||
|
|
||||||
|
All the image fixtures for the test is under tests/images.
|
||||||
|
For huggingface runner, we provide the PIL images as input.
|
||||||
|
For vllm runner, we provide MultiModalDataDict objects
|
||||||
|
and corresponding vision language config as input.
|
||||||
|
Note, the text input is also adjusted to abide by vllm contract.
|
||||||
|
The text output is sanitized to be able to compare with hf.
|
||||||
|
"""
|
||||||
|
images = [asset.pil_image for asset in image_assets]
|
||||||
|
|
||||||
|
inputs_per_case = [
|
||||||
|
([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
|
||||||
|
[[rescale_image_size(image, factor) for image in images]
|
||||||
|
for factor in size_factors])
|
||||||
|
]
|
||||||
|
|
||||||
|
# NOTE: take care of the order. run vLLM first, and then run HF.
|
||||||
|
# vLLM needs a fresh new process without cuda initialization.
|
||||||
|
# if we run HF first, the cuda initialization will be done and it
|
||||||
|
# will hurt multiprocessing backend with fork method (the default method).
|
||||||
|
|
||||||
|
# max_model_len should be greater than image_feature_size
|
||||||
|
with vllm_runner(model,
|
||||||
|
max_model_len=4096,
|
||||||
|
max_num_seqs=1,
|
||||||
|
dtype=dtype,
|
||||||
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
|
distributed_executor_backend=distributed_executor_backend,
|
||||||
|
enforce_eager=True) as vllm_model:
|
||||||
|
tokenizer = vllm_model.model.get_tokenizer()
|
||||||
|
stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
|
||||||
|
vllm_outputs_per_case = [
|
||||||
|
vllm_model.generate_greedy_logprobs(prompts,
|
||||||
|
max_tokens,
|
||||||
|
num_logprobs=num_logprobs,
|
||||||
|
images=images,
|
||||||
|
stop_token_ids=stop_token_ids)
|
||||||
|
for prompts, images in inputs_per_case
|
||||||
|
]
|
||||||
|
|
||||||
|
with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
|
||||||
|
hf_processor = hf_model.processor
|
||||||
|
hf_model.processor = lambda **kw: NestedInputs(
|
||||||
|
hf_processor(**kw) # type: ignore
|
||||||
|
)
|
||||||
|
hf_outputs_per_case = [
|
||||||
|
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||||
|
max_tokens,
|
||||||
|
num_logprobs=num_logprobs,
|
||||||
|
images=images,
|
||||||
|
tokenizer=tokenizer)
|
||||||
|
for prompts, images in inputs_per_case
|
||||||
|
]
|
||||||
|
|
||||||
|
for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
|
||||||
|
vllm_outputs_per_case):
|
||||||
|
check_logprobs_close(
|
||||||
|
outputs_0_lst=[
|
||||||
|
trunc_hf_output(hf_output) for hf_output in hf_outputs
|
||||||
|
],
|
||||||
|
outputs_1_lst=vllm_outputs,
|
||||||
|
name_0="hf",
|
||||||
|
name_1="vllm",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", models)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"size_factors",
|
||||||
|
[
|
||||||
|
# No image
|
||||||
|
[],
|
||||||
|
# Single-scale
|
||||||
|
[1.0],
|
||||||
|
# Single-scale, batched
|
||||||
|
[1.0, 1.0, 1.0],
|
||||||
|
# Multi-scale
|
||||||
|
[0.25, 0.5, 1.0],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("dtype", [target_dtype])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [128])
|
||||||
|
@pytest.mark.parametrize("num_logprobs", [5])
|
||||||
|
def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
|
||||||
|
size_factors, dtype: str, max_tokens: int,
|
||||||
|
num_logprobs: int) -> None:
|
||||||
|
run_multi_image_test(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
image_assets,
|
||||||
|
model,
|
||||||
|
size_factors=size_factors,
|
||||||
|
dtype=dtype,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
num_logprobs=num_logprobs,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
)
|
||||||
|
@ -392,6 +392,20 @@ class Resampler2_5(BaseResampler):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
|
||||||
|
version_float = getattr(config, "version", None)
|
||||||
|
|
||||||
|
# The old configs do not include version number
|
||||||
|
# TODO: Remove this after the HF repos are updated
|
||||||
|
if version_float is None:
|
||||||
|
if config.hidden_size == 2304 and config.query_num == 64:
|
||||||
|
return (2, 0)
|
||||||
|
return (2, 5)
|
||||||
|
|
||||||
|
version_str = str(version_float)
|
||||||
|
return tuple(int(x) for x in version_str.split("."))
|
||||||
|
|
||||||
|
|
||||||
def get_max_minicpmv_image_tokens(ctx: InputContext):
|
def get_max_minicpmv_image_tokens(ctx: InputContext):
|
||||||
hf_config = ctx.get_hf_config(PretrainedConfig)
|
hf_config = ctx.get_hf_config(PretrainedConfig)
|
||||||
return getattr(hf_config, "query_num", 64)
|
return getattr(hf_config, "query_num", 64)
|
||||||
@ -421,36 +435,43 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
|
|||||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||||
return llm_inputs
|
return llm_inputs
|
||||||
|
|
||||||
model_config = ctx.model_config
|
model_config = ctx.model_config
|
||||||
|
version = get_version_by_config(model_config.hf_config)
|
||||||
tokenizer = cached_get_tokenizer(model_config.tokenizer,
|
tokenizer = cached_get_tokenizer(model_config.tokenizer,
|
||||||
trust_remote_code=True)
|
trust_remote_code=True)
|
||||||
|
image_processor = cached_get_image_processor(model_config.tokenizer)
|
||||||
|
|
||||||
|
def get_placeholder(image_size: Tuple[int, int], num_image: int):
|
||||||
|
if version == (2, 0) or version == (2, 5):
|
||||||
|
return image_processor. \
|
||||||
|
get_slice_image_placeholder(image_size)
|
||||||
|
return image_processor. \
|
||||||
|
get_slice_image_placeholder(image_size, num_image)
|
||||||
|
|
||||||
prompt = llm_inputs.get("prompt")
|
prompt = llm_inputs.get("prompt")
|
||||||
if prompt is None:
|
if prompt is None:
|
||||||
token_ids = llm_inputs.get("prompt_token_ids")
|
token_ids = llm_inputs.get("prompt_token_ids")
|
||||||
prompt = tokenizer.decode(token_ids)
|
prompt = tokenizer.decode(token_ids)
|
||||||
image_processor = cached_get_image_processor(model_config.tokenizer)
|
|
||||||
|
|
||||||
pattern = "(<image>./</image>)"
|
pattern = "(<image>./</image>)"
|
||||||
image = multi_modal_data["image"]
|
images = multi_modal_data["image"]
|
||||||
|
if isinstance(images, Image.Image):
|
||||||
|
images = [images]
|
||||||
image_tags = re.findall(pattern, prompt)
|
image_tags = re.findall(pattern, prompt)
|
||||||
|
|
||||||
if len(image_tags) == 0:
|
if len(image_tags) == 0:
|
||||||
new_token_ids = token_ids
|
new_token_ids = token_ids
|
||||||
new_prompt = prompt
|
new_prompt = prompt
|
||||||
else:
|
else:
|
||||||
if len(image_tags) > 1:
|
|
||||||
logger.warning("Multiple image input is not supported yet, "
|
|
||||||
"so any extra image tokens will be treated "
|
|
||||||
"as plain text.")
|
|
||||||
|
|
||||||
text_chunks = prompt.split(pattern)
|
text_chunks = prompt.split(pattern)
|
||||||
new_prompt = (text_chunks[0] +
|
new_prompt_chunks: List[str] = []
|
||||||
image_processor.get_slice_image_placeholder(image.size) +
|
for i in range(len(images)):
|
||||||
"".join(text_chunks[1:]))
|
new_prompt_chunks += [
|
||||||
|
text_chunks[i],
|
||||||
|
get_placeholder(images[i].size, i)
|
||||||
|
]
|
||||||
|
new_prompt_chunks.append(text_chunks[-1])
|
||||||
|
new_prompt = "".join(new_prompt_chunks)
|
||||||
new_token_ids = tokenizer.encode(new_prompt)
|
new_token_ids = tokenizer.encode(new_prompt)
|
||||||
|
|
||||||
llm_inputs = LLMInputs(
|
llm_inputs = LLMInputs(
|
||||||
@ -478,14 +499,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsVision):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.multimodal_config = multimodal_config
|
self.multimodal_config = multimodal_config
|
||||||
|
|
||||||
if not hasattr(self.config, "version"):
|
self.version = get_version_by_config(self.config)
|
||||||
if self.config.hidden_size == 2304 and self.config.query_num == 64:
|
|
||||||
self.version = (2, 0)
|
|
||||||
else:
|
|
||||||
self.version = (2, 5)
|
|
||||||
else:
|
|
||||||
self.version = str(self.config.version).split(".")
|
|
||||||
self.version = tuple([int(x) for x in self.version])
|
|
||||||
self.llm = self.init_llm(config, cache_config, quant_config)
|
self.llm = self.init_llm(config, cache_config, quant_config)
|
||||||
self.vpm = self.init_vision_module()
|
self.vpm = self.init_vision_module()
|
||||||
param_dtype = torch.get_default_dtype()
|
param_dtype = torch.get_default_dtype()
|
||||||
|
@ -113,7 +113,7 @@ class ImagePlugin(MultiModalPlugin):
|
|||||||
def _default_input_mapper(self, ctx: InputContext,
|
def _default_input_mapper(self, ctx: InputContext,
|
||||||
data: object) -> MultiModalInputs:
|
data: object) -> MultiModalInputs:
|
||||||
model_config = ctx.model_config
|
model_config = ctx.model_config
|
||||||
if isinstance(data, Image.Image):
|
if isinstance(data, (Image.Image, list)):
|
||||||
image_processor = self._get_hf_image_processor(model_config)
|
image_processor = self._get_hf_image_processor(model_config)
|
||||||
if image_processor is None:
|
if image_processor is None:
|
||||||
raise RuntimeError("No HuggingFace processor is available "
|
raise RuntimeError("No HuggingFace processor is available "
|
||||||
|
Loading…
x
Reference in New Issue
Block a user