[Model] Multi-input support for LLaVA (#8238)

Cyrus Leung 2024-09-07 10:57:24 +08:00 committed by GitHub
parent 41e95c5247
commit 2f707fcb35
10 changed files with 176 additions and 45 deletions

View File

@@ -219,7 +219,7 @@ Multimodal Language Models
     -
   * - :code:`LlavaForConditionalGeneration`
     - LLaVA-1.5
-    - Image\ :sup:`E`
+    - Image\ :sup:`E+`
     - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
     -
   * - :code:`LlavaNextForConditionalGeneration`
@@ -227,6 +227,11 @@ Multimodal Language Models
     - Image\ :sup:`E+`
     - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
     -
+  * - :code:`MiniCPMV`
+    - MiniCPM-V
+    - Image\ :sup:`+`
+    - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
+    -
   * - :code:`PaliGemmaForConditionalGeneration`
     - PaliGemma
     - Image\ :sup:`E`
@@ -237,14 +242,9 @@ Multimodal Language Models
     - Image\ :sup:`E+`
     - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc.
     -
-  * - :code:`MiniCPMV`
-    - MiniCPM-V
-    - Image\ :sup:`+`
-    - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
-    -
   * - :code:`QWenLMHeadModel`
-    - Qwen
-    - Image
+    - Qwen-VL
+    - Image\ :sup:`E`
     - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
     -
   * - :code:`UltravoxModel`
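To show what the new E+ marking (multiple images per text prompt) looks like from the user side, here is a minimal offline-inference sketch for LLaVA-1.5. It assumes the vLLM LLM API of this release and two local image files (the filenames are placeholders); the prompt format and the limit_mm_per_prompt option mirror the tests added in this commit.

from PIL import Image

from vllm import LLM, SamplingParams

# Allow up to 2 images per prompt (the default limit is 1).
llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",
    limit_mm_per_prompt={"image": 2},
)

prompt = "USER: <image><image>\nDescribe 2 images.\nASSISTANT:"
images = [Image.open("stop_sign.jpg"), Image.open("cherry_blossom.jpg")]

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": images},
    },
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)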

View File

@@ -278,7 +278,7 @@ class HfRunner:
     def generate(
         self,
         prompts: List[str],
-        images: Optional[List[Image.Image]] = None,
+        images: Optional[PromptImageInput] = None,
         **kwargs: Any,
     ) -> List[Tuple[List[List[int]], List[str]]]:
         if images:
@@ -314,7 +314,7 @@ class HfRunner:
         self,
         prompts: List[str],
         max_tokens: int,
-        images: Optional[List[Image.Image]] = None,
+        images: Optional[PromptImageInput] = None,
         **kwargs: Any,
     ) -> List[Tuple[List[int], str]]:
         outputs = self.generate(prompts,
@@ -351,7 +351,7 @@ class HfRunner:
         self,
         prompts: List[str],
         max_tokens: int,
-        images: Optional[List[Image.Image]] = None,
+        images: Optional[PromptImageInput] = None,
         **kwargs: Any,
     ) -> List[List[torch.Tensor]]:
         all_logprobs: List[List[torch.Tensor]] = []
@@ -433,8 +433,8 @@ class HfRunner:
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-        images: Optional[List[Image.Image]] = None,
-        audios: Optional[List[Tuple[np.ndarray, int]]] = None,
+        images: Optional[PromptImageInput] = None,
+        audios: Optional[PromptAudioInput] = None,
         **kwargs: Any,
     ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
         all_logprobs: List[List[Dict[int, float]]] = []
@@ -671,7 +671,7 @@ class VllmRunner:
         self,
         prompts: List[str],
         max_tokens: int,
-        images: Optional[List[Image.Image]] = None,
+        images: Optional[PromptImageInput] = None,
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
         outputs = self.generate(prompts, greedy_params, images=images)
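The PromptImageInput and PromptAudioInput aliases used in the new annotations are defined elsewhere in conftest.py and do not appear in this diff. A plausible definition, consistent with the old per-prompt annotations but additionally allowing a list of images or audios per prompt, would be:

from typing import List, Tuple, Union

import numpy as np
from PIL import Image

# Assumed aliases (not shown in the diff): one entry per prompt, where each
# entry is either a single item or a list of items for multi-input prompts.
PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
PromptAudioInput = Union[
    List[Tuple[np.ndarray, int]],
    List[List[Tuple[np.ndarray, int]]],
]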

View File

@@ -35,9 +35,11 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str,
     if model.startswith("llava-hf/llava-1.5"):
         from ..models.test_llava import models, run_test
     elif model.startswith("llava-hf/llava-v1.6"):
-        from ..models.test_llava_next import models, run_test
+        from ..models.test_llava_next import run_test  # type: ignore[no-redef]
+        from ..models.test_llava_next import models
     elif model.startswith("facebook/chameleon"):
-        from ..models.test_chameleon import models, run_test
+        from ..models.test_chameleon import run_test  # type: ignore[no-redef]
+        from ..models.test_chameleon import models
     else:
         raise NotImplementedError(f"Unsupported model: {model}")

View File

@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type, overload

 import pytest
 from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
@@ -8,11 +8,14 @@ from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

-from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
+                        _ImageAssets)
 from .utils import check_logprobs_close

 pytestmark = pytest.mark.vlm

+_LIMIT_IMAGE_PER_PROMPT = 4
+
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "stop_sign":
     "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
@@ -52,6 +55,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
     return hf_output_ids, hf_output_str, out_logprobs


+@overload
 def run_test(
     hf_runner: Type[HfRunner],
     vllm_runner: Type[VllmRunner],
@@ -64,6 +68,78 @@ def run_test(
     num_logprobs: int,
     tensor_parallel_size: int,
     distributed_executor_backend: Optional[str] = None,
+):
+    ...
+
+
+@overload
+def run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    sizes: List[Tuple[int, int]],
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    ...
+
+
+def run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    size_factors: Optional[List[float]] = None,
+    sizes: Optional[List[Tuple[int, int]]] = None,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    images = [asset.pil_image for asset in image_assets]
+
+    if size_factors is not None:
+        inputs_per_image = [(
+            [prompt for _ in size_factors],
+            [rescale_image_size(image, factor) for factor in size_factors],
+        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+    elif sizes is not None:
+        inputs_per_image = [(
+            [prompt for _ in sizes],
+            [image.resize(size) for size in sizes],
+        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+    else:
+        raise ValueError("You must provide either `size_factors` or `sizes`")
+
+    _run_test(hf_runner,
+              vllm_runner,
+              inputs_per_image,
+              model,
+              dtype=dtype,
+              max_tokens=max_tokens,
+              num_logprobs=num_logprobs,
+              tensor_parallel_size=tensor_parallel_size,
+              distributed_executor_backend=distributed_executor_backend)
+
+
+def _run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    inputs: List[Tuple[List[str], PromptImageInput]],
+    model: str,
+    *,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
 ):
     """Inference result should be the same between hf and vllm.
@@ -85,13 +161,6 @@ def run_test(hf_runner, vllm_runner, image_assets, model, size_factors,
     else:
         mantis_processor = None

-    images = [asset.pil_image for asset in image_assets]
-
-    inputs_per_image = [(
-        [prompt for _ in size_factors],
-        [rescale_image_size(image, factor) for factor in size_factors],
-    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
-
     # NOTE: take care of the order. run vLLM first, and then run HF.
     # vLLM needs a fresh new process without cuda initialization.
     # if we run HF first, the cuda initialization will be done and it
@@ -100,15 +169,18 @@ def run_test(
     # max_model_len should be greater than image_feature_size
     with vllm_runner(model,
                      dtype=dtype,
+                     max_model_len=4096,
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
-                     enforce_eager=True) as vllm_model:
+                     enforce_eager=True,
+                     limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
+                                          }) as vllm_model:
         vllm_outputs_per_image = [
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,
                                                 num_logprobs=num_logprobs,
                                                 images=images)
-            for prompts, images in inputs_per_image
+            for prompts, images in inputs
         ]

     if mantis_processor is not None:
@@ -131,7 +203,7 @@ def run_test(
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images)
-            for prompts, images in inputs_per_image
+            for prompts, images in inputs
         ]

     for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
@@ -181,6 +253,51 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
     )


+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
+                                      model, dtype, max_tokens,
+                                      num_logprobs) -> None:
+    stop_sign = image_assets[0].pil_image
+    cherry_blossom = image_assets[1].pil_image
+
+    inputs = [(
+        [
+            "USER: <image><image>\nDescribe 2 images.\nASSISTANT:",
+            "USER: <image><image>\nDescribe 2 images.\nASSISTANT:",
+            "USER: <image><image><image><image>\nDescribe 4 images.\nASSISTANT:",  # noqa: E501
+            "USER: <image>\nWhat is the season?\nASSISTANT:",
+        ],
+        [
+            [stop_sign, cherry_blossom],
+            # Images with different sizes and aspect-ratios
+            [
+                rescale_image_size(stop_sign, 0.1),
+                stop_sign,
+            ],
+            [
+                stop_sign,
+                rescale_image_size(stop_sign, 0.25),
+                cherry_blossom.resize((183, 488)),
+                cherry_blossom.resize((488, 183))
+            ],
+            cherry_blossom,
+        ])]
+
+    _run_test(
+        hf_runner,
+        vllm_runner,
+        inputs,
+        model,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
+
+
 @pytest.mark.parametrize("model", models)
 def test_context_length_too_short(vllm_runner, image_assets, model):
     images = [asset.pil_image for asset in image_assets]
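The two @overload stubs on run_test exist only for type checkers: a caller must pass exactly one of the keyword-only arguments size_factors or sizes, and the single runtime implementation dispatches on whichever is set before delegating to _run_test. A minimal, self-contained sketch of that pattern follows; the names are illustrative and not part of the PR.

from typing import List, Optional, Tuple, overload


@overload
def prepare_inputs(*, size_factors: List[float]) -> List[float]:
    ...


@overload
def prepare_inputs(*, sizes: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    ...


def prepare_inputs(
    *,
    size_factors: Optional[List[float]] = None,
    sizes: Optional[List[Tuple[int, int]]] = None,
):
    # Exactly one of the keyword-only arguments selects the preprocessing mode.
    if size_factors is not None:
        return size_factors
    if sizes is not None:
        return sizes
    raise ValueError("You must provide either `size_factors` or `sizes`")


print(prepare_inputs(size_factors=[0.25, 0.5, 1.0]))
print(prepare_inputs(sizes=[(183, 488), (488, 183)]))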

View File

@@ -105,7 +105,7 @@ def input_processor_for_clip(
         if isinstance(image_data, Image.Image):
             image_feature_size = get_clip_image_feature_size(hf_config)
         elif isinstance(image_data, torch.Tensor):
-            image_feature_size = image_data.shape[0]
+            num_images, image_feature_size, hidden_size = image_data.shape
         else:
             raise TypeError(f"Invalid image type: {type(image_data)}")
     else:
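The unpacking change above encodes the new shape convention for precomputed image embeddings: a prompt's embeddings now arrive as a 3-D tensor with an explicit image dimension, so the per-image feature size is the second axis rather than shape[0]. An illustrative check with assumed example figures (576 feature tokens and hidden size 1024 correspond to a CLIP ViT-L/14 tower at 336px, as used by LLaVA-1.5; the values are not taken from this diff):

import torch

# Two images, 576 feature tokens each, hidden size 1024 (assumed example values).
image_embeds = torch.randn(2, 576, 1024)

num_images, image_feature_size, hidden_size = image_embeds.shape
assert (num_images, image_feature_size, hidden_size) == (2, 576, 1024)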

View File

@@ -209,7 +209,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
         image_feature_size = num_blocks * num_patches
     elif isinstance(image_data, torch.Tensor):
-        image_feature_size = image_data.shape[0]
+        num_images, image_feature_size, hidden_size = image_data.shape
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")

View File

@@ -4,6 +4,7 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
 import torch
 import torch.nn as nn
+from PIL import Image
 from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig

 from vllm.attention import AttentionMetadata
@@ -16,6 +17,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
+from vllm.utils import is_list_of

 from .clip import (CLIPVisionModel, dummy_image_for_clip,
                    dummy_seq_data_for_clip, get_max_clip_image_tokens,
@@ -24,7 +26,7 @@ from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
-from .utils import (filter_weights, init_vllm_registered_model,
+from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
                     merge_multimodal_embeddings)
@@ -133,7 +135,18 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
     hf_config = ctx.get_hf_config(LlavaConfig)
     vision_config = hf_config.vision_config

-    image_feature_size = get_max_llava_image_tokens(ctx)
+    image_data = multi_modal_data["image"]
+    if isinstance(image_data, Image.Image):
+        image_feature_size = get_max_llava_image_tokens(ctx)
+    elif is_list_of(image_data, Image.Image):
+        image_feature_size = [get_max_llava_image_tokens(ctx)
+                              ] * len(image_data)
+    elif isinstance(image_data, torch.Tensor):
+        num_images, image_feature_size, hidden_size = image_data.shape
+    elif is_list_of(image_data, torch.Tensor):
+        image_feature_size = [item.shape[1] for item in image_data]
+    else:
+        raise TypeError(f"Invalid image type: {type(image_data)}")

     if isinstance(vision_config, CLIPVisionConfig):
         return input_processor_for_clip(
@@ -230,29 +243,24 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
             return None

         if pixel_values is not None:
-            if not isinstance(pixel_values, torch.Tensor):
+            if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")

-            # Remove the N dimension until multiple images are supported.
-            pixel_values = pixel_values.squeeze(1)
-
             return LlavaImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(pixel_values),
+                data=self._validate_pixel_values(
+                    flatten_bn(pixel_values, concat=True)),
             )

         if image_embeds is not None:
-            if not isinstance(image_embeds, torch.Tensor):
+            if not isinstance(image_embeds, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")

-            # Remove the N dimension until multiple images are supported.
-            image_embeds = image_embeds.squeeze(1)
-
             return LlavaImageEmbeddingInputs(
                 type="image_embeds",
-                data=image_embeds,
+                data=flatten_bn(image_embeds, concat=True),
             )

         raise AssertionError("This line should be unreachable.")
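flatten_bn, imported from .utils above, collapses the batch dimension and the per-prompt image dimension of the multimodal keyword inputs into a single flat image axis before validation; with concat=True the result is one tensor covering every image in the batch. Its real implementation lives in vllm/model_executor/models/utils.py and is not part of this diff; the following is a rough, assumption-based sketch of the behaviour the LLaVA code relies on, not the actual helper.

from typing import List, Union

import torch


def flatten_bn_sketch(
    x: Union[torch.Tensor, List[torch.Tensor]],
    *,
    concat: bool = False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    if isinstance(x, torch.Tensor):
        # Batched tensor of shape (batch, num_images, ...): merge the first two dims.
        return x.flatten(0, 1)
    if concat:
        # One tensor of shape (num_images, ...) per prompt: join along the image axis.
        return torch.cat(x)
    # Otherwise flatten into a plain list of per-image tensors.
    return [img for per_prompt in x for img in per_prompt]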

View File

@@ -234,7 +234,9 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
             for img in image_data
         ]
     elif isinstance(image_data, torch.Tensor):
-        image_feature_size = image_data.shape[0]
+        num_images, image_feature_size, hidden_size = image_data.shape
+    elif is_list_of(image_data, torch.Tensor):
+        image_feature_size = [item.shape[1] for item in image_data]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")

View File

@@ -424,7 +424,9 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
                                              input_width=w,
                                              input_height=h))
     elif isinstance(image_data, torch.Tensor):
-        image_feature_size = image_data.shape[0]
+        num_images, image_feature_size, hidden_size = image_data.shape
+    elif is_list_of(image_data, torch.Tensor):
+        image_feature_size = [item.shape[1] for item in image_data]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")

View File

@@ -110,7 +110,7 @@ def input_processor_for_siglip(
         if isinstance(image_data, Image.Image):
             image_feature_size = get_siglip_image_feature_size(hf_config)
         elif isinstance(image_data, torch.Tensor):
-            image_feature_size = image_data.shape[0]
+            num_images, image_feature_size, hidden_size = image_data.shape
         else:
             raise TypeError(f"Invalid image type: {type(image_data)}")
     else: