[Model] Support input embeddings for Qwen2-VL (#8856)

whyiug 2024-09-30 11:16:10 +08:00 committed by GitHub
parent f13a07b1f8
commit e01ab595d8
3 changed files with 135 additions and 70 deletions


@@ -281,7 +281,7 @@ Multimodal Language Models
    -
  * - :code:`Qwen2VLForConditionalGeneration`
    - Qwen2-VL
-    - Image\ :sup:`+` / Video\ :sup:`+`
+    - Image\ :sup:`E+` / Video\ :sup:`+`
    - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
    -
  * - :code:`UltravoxModel`


@@ -57,6 +57,23 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`:
         "multi_modal_data": {"image": image_embeds},
     })
 
+    for o in outputs:
+        generated_text = o.outputs[0].text
+        print(generated_text)
+
+    # Inference with image embeddings as input with additional parameters
+    # Specifically, we are conducting a trial run of Qwen2VL with the new input format, as the model utilizes additional parameters for calculating positional encoding.
+    image_embeds = torch.load(...)  # torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
+    image_grid_thw = torch.load(...)  # torch.Tensor of shape (1, 3)
+    mm_data['image'] = {
+        "image_embeds": image_embeds,
+        "image_grid_thw": image_grid_thw,
+    }
+    outputs = llm.generate({
+        "prompt": prompt,
+        "multi_modal_data": mm_data,
+    })
+
     for o in outputs:
         generated_text = o.outputs[0].text
         print(generated_text)
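For context, a minimal end-to-end sketch of how the new embedding input path is exercised through vLLM's LLM API, assuming pre-computed image_embeds and image_grid_thw tensors saved to disk; the checkpoint name, file names, prompt template, and sampling settings are illustrative, not part of this commit.

import torch
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-VL-2B-Instruct")  # any Qwen2-VL checkpoint

# Pre-computed vision features; shapes follow the docstrings added in this commit.
image_embeds = torch.load("image_embeds.pt")      # hypothetical file name
image_grid_thw = torch.load("image_grid_thw.pt")  # hypothetical file name

# Illustrative Qwen2-VL chat prompt containing one image placeholder.
prompt = ("<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
          "Describe the image.<|im_end|>\n<|im_start|>assistant\n")

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {
            "image": {
                "image_embeds": image_embeds,
                "image_grid_thw": image_grid_thw,
            }
        },
    },
    SamplingParams(max_tokens=64),
)

for o in outputs:
    print(o.outputs[0].text)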


@@ -23,8 +23,8 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 from functools import lru_cache, partial
-from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict,
-                    Union)
+from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
+                    Tuple, Type, TypedDict, Union)
 
 import torch
 import torch.nn as nn
@@ -76,19 +76,31 @@ logger = init_logger(__name__)
 
 # === Vision Inputs === #
 
 
-class Qwen2VLImageInputs(TypedDict):
-    pixel_values: torch.Tensor
+class Qwen2VLImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: torch.Tensor
     """Shape:
     `(num_patches, num_channels * patch_size * patch_size)`
     """
 
     image_grid_thw: torch.Tensor
     """Shape: `(num_images, 3)`
 
     This should be in `(grid_t, grid_h, grid_w)` format.
     """
 
 
+class Qwen2VLImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
+                           Qwen2VLImageEmbeddingInputs]
+
+
 class Qwen2VLVideoInputs(TypedDict):
     pixel_values_videos: torch.Tensor
     """Shape:
@@ -567,6 +579,11 @@ def mm_input_mapper_for_qwen2_vl(
     data_type_key: str,
 ) -> MultiModalInputs:
     """Input mapper for Qwen2-VL."""
+    if data_type_key == "image" and isinstance(data, dict):
+        return MultiModalInputs({
+            "image_embeds": data.get("image_embeds"),
+            "image_grid_thw": data.get("image_grid_thw"),
+        })
     model_config = ctx.model_config
     image_processor = cached_get_image_processor(
         model_config.model, trust_remote_code=model_config.trust_remote_code)
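In the mapper, a dict payload signals the embedding path: the tensors are forwarded verbatim as model keyword arguments instead of being run through the HF image processor. A minimal sketch of that dispatch, with the pixel path replaced by a stub (the stub processor and function name are illustrative, not vLLM code):

import torch


def stub_image_processor(image) -> dict:
    # Illustrative stand-in for the HF Qwen2-VL image processor output.
    return {"pixel_values": torch.randn(256, 1176),
            "image_grid_thw": torch.tensor([[1, 16, 16]])}


def map_qwen2_vl_image(data) -> dict:
    """Return the kwargs that will eventually reach the model's forward()."""
    if isinstance(data, dict):
        # Embedding path added by this commit: pass tensors straight through.
        return {"image_embeds": data.get("image_embeds"),
                "image_grid_thw": data.get("image_grid_thw")}
    # Pixel path (unchanged): preprocess the raw image.
    return stub_image_processor(data)


print(map_qwen2_vl_image({"image_embeds": torch.randn(64, 1536),
                          "image_grid_thw": torch.tensor([[1, 16, 16]])}).keys())
# dict_keys(['image_embeds', 'image_grid_thw'])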
@@ -739,6 +756,48 @@ def _get_llm_num_vision_tokens(
     return llm_num_vision_tokens
 
 
+def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
+                       data_type_key: str, image_processor: Any,
+                       prompt_token_ids: List[int]) -> List[int]:
+    """
+    Expand pad tokens for multi-modal inputs (e.g., images or videos).
+
+    Args:
+        inputs (list): The multi-modal inputs (e.g., images or videos).
+        token_id (int): The token ID used to represent the multi-modal input.
+        make_batched_fn (Callable): A function to batch the inputs.
+        data_type_key (str): The type of the multi-modal input.
+        image_processor (Any): The image processor used to process the inputs.
+        prompt_token_ids (List[int]): The list of token IDs in the prompt.
+
+    Returns:
+        List[int]: The list of token IDs for the multi-modal inputs.
+    """
+    indices = [
+        idx for idx, token in enumerate(prompt_token_ids) if token == token_id
+    ]
+    inputs = make_batched_fn(inputs)
+    assert len(indices) == len(inputs)
+
+    prompt_token_ids_with_data = []
+    for cnt, data in enumerate(inputs):
+        num_tokens = _get_llm_num_vision_tokens(
+            [data] if data_type_key == "image" else data,
+            data_type_key=data_type_key,
+            image_processor=image_processor,
+        )
+        if cnt == 0:
+            end_idx = indices[cnt]
+            non_data_tokens = prompt_token_ids[:end_idx]
+        else:
+            non_data_tokens = prompt_token_ids[indices[cnt - 1] +
+                                               1:indices[cnt]]
+        prompt_token_ids_with_data.extend(non_data_tokens)
+        prompt_token_ids_with_data.extend(token_id for _ in range(num_tokens))
+    prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:])
+    return prompt_token_ids_with_data
+
+
 def input_processor_for_qwen2_vl(ctx: InputContext,
                                  llm_inputs: LLMInputs) -> LLMInputs:
     multi_modal_data = llm_inputs.get("multi_modal_data", None)
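The new helper consolidates the previously duplicated image and video expansion loops: each placeholder token is replaced by as many copies as the encoder will emit vision tokens for that item, leaving the surrounding text tokens intact. A standalone toy version with the per-item token counts supplied directly instead of via `_get_llm_num_vision_tokens` (the token ID and counts are illustrative):

from typing import List


def expand_pad_tokens_toy(prompt_token_ids: List[int], token_id: int,
                          num_tokens_per_item: List[int]) -> List[int]:
    """Replace each occurrence of `token_id` with a run of pad tokens."""
    indices = [i for i, t in enumerate(prompt_token_ids) if t == token_id]
    assert len(indices) == len(num_tokens_per_item)

    out: List[int] = []
    prev = 0
    for idx, n in zip(indices, num_tokens_per_item):
        out.extend(prompt_token_ids[prev:idx])   # text tokens before the placeholder
        out.extend(token_id for _ in range(n))   # expanded placeholder run
        prev = idx + 1
    out.extend(prompt_token_ids[prev:])          # trailing text tokens
    return out


# Two images in the prompt (placeholder id 151655), producing 4 and 6 vision tokens.
print(expand_pad_tokens_toy([1, 151655, 2, 151655, 3], 151655, [4, 6]))
# -> [1, 151655, 151655, 151655, 151655, 2,
#     151655, 151655, 151655, 151655, 151655, 151655, 3]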
@@ -775,62 +834,38 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
     )["input_ids"]
 
     # Expand image pad tokens.
     if image_inputs is not None:
-        image_indices = [
-            idx for idx, token in enumerate(prompt_token_ids)
-            if token == hf_config.image_token_id
-        ]
-        image_inputs = make_batched_images(image_inputs)
-        assert len(image_indices) == len(image_inputs)
-
-        prompt_token_ids_with_image = []
-        for image_cnt, image in enumerate(image_inputs):
-            num_image_tokens = _get_llm_num_vision_tokens(
-                [image],
-                data_type_key="image",
-                image_processor=image_processor,
-            )
-            if image_cnt == 0:
-                non_image_tokens = prompt_token_ids[:image_indices[image_cnt]]
-            else:
-                non_image_tokens = prompt_token_ids[image_indices[image_cnt -
-                                                                  1] +
-                                                    1:image_indices[image_cnt]]
-            prompt_token_ids_with_image.extend(non_image_tokens)
-            prompt_token_ids_with_image.extend(
-                hf_config.image_token_id for _ in range(num_image_tokens))
-        prompt_token_ids_with_image.extend(prompt_token_ids[image_indices[-1] +
-                                                            1:])
-        prompt_token_ids = prompt_token_ids_with_image
-
-    # Expand video pad tokens.
-    if video_inputs is not None:
-        video_indices = [
-            idx for idx, token in enumerate(prompt_token_ids)
-            if token == hf_config.video_token_id
-        ]
-        video_inputs = make_batched_videos(video_inputs)
-        assert len(video_indices) == len(video_inputs)
-
-        prompt_token_ids_with_video = []
-        for video_cnt, video in enumerate(video_inputs):
-            num_video_tokens = _get_llm_num_vision_tokens(
-                video,
-                data_type_key="video",
-                image_processor=image_processor,
-            )
-            if video_cnt == 0:
-                non_video_tokens = prompt_token_ids[:video_indices[video_cnt]]
-            else:
-                non_video_tokens = prompt_token_ids[video_indices[video_cnt -
-                                                                  1] +
-                                                    1:video_indices[video_cnt]]
-            prompt_token_ids_with_video.extend(non_video_tokens)
-            prompt_token_ids_with_video.extend(
-                hf_config.video_token_id for _ in range(num_video_tokens))
-        prompt_token_ids_with_video.extend(prompt_token_ids[video_indices[-1] +
-                                                            1:])
-        prompt_token_ids = prompt_token_ids_with_video
+        if isinstance(image_inputs, dict):
+            prompt_token_ids_with_image = []
+            image_indices = [
+                idx for idx, token in enumerate(prompt_token_ids)
+                if token == hf_config.image_token_id
+            ]
+            image_cnt = len(image_indices)
+            embed_dim = image_inputs.get('image_embeds').size(0)
+            assert embed_dim % image_cnt == 0
+            num_pad_tokens = embed_dim // image_cnt
+            for idx, token in enumerate(prompt_token_ids):
+                if idx in image_indices:
+                    prompt_token_ids_with_image.extend([token] *
+                                                       num_pad_tokens)
+                else:
+                    prompt_token_ids_with_image.append(token)
+            prompt_token_ids = prompt_token_ids_with_image
+        else:
+            prompt_token_ids = _expand_pad_tokens(image_inputs,
+                                                  hf_config.image_token_id,
+                                                  make_batched_images, "image",
+                                                  image_processor,
+                                                  prompt_token_ids)
+
+    if video_inputs is not None:
+        prompt_token_ids = _expand_pad_tokens(video_inputs,
+                                              hf_config.video_token_id,
+                                              make_batched_videos, "video",
+                                              image_processor,
+                                              prompt_token_ids)
 
     return LLMInputs(
         prompt_token_ids=prompt_token_ids,
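On the embedding path the processor cannot ask the image processor how many vision tokens each image produces, so it splits image_embeds.size(0) evenly across the image placeholders. A toy check of that arithmetic (all sizes and token IDs are illustrative):

import torch

# Pretend two images were encoded into one stacked embedding tensor.
hidden_size = 1536                                # illustrative LM hidden size
image_embeds = torch.randn(2 * 64, hidden_size)   # 64 vision tokens per image

prompt_token_ids = [1, 151655, 2, 151655, 3]      # 151655 = image placeholder id
image_indices = [i for i, t in enumerate(prompt_token_ids) if t == 151655]

image_cnt = len(image_indices)
embed_dim = image_embeds.size(0)
assert embed_dim % image_cnt == 0
num_pad_tokens = embed_dim // image_cnt           # 128 // 2 == 64

expanded = []
for i, t in enumerate(prompt_token_ids):
    expanded.extend([t] * num_pad_tokens if i in image_indices else [t])

assert len(expanded) == len(prompt_token_ids) + image_cnt * (num_pad_tokens - 1)
print(num_pad_tokens, len(expanded))  # 64 131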
@@ -910,11 +945,13 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Qwen2VLImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
         image_grid_thw = kwargs.pop("image_grid_thw", None)
 
-        if pixel_values is None:
+        if pixel_values is None and image_embeds is None:
             return None
 
+        if pixel_values is not None:
             pixel_values = self._validate_and_reshape_mm_tensor(
                 pixel_values, "image pixel values")
             image_grid_thw = self._validate_and_reshape_mm_tensor(
@@ -924,9 +961,17 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of image pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
-        return Qwen2VLImageInputs(pixel_values=pixel_values,
-                                  image_grid_thw=image_grid_thw)
+            return Qwen2VLImagePixelInputs(type="pixel_values",
+                                           data=pixel_values,
+                                           image_grid_thw=image_grid_thw)
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+            return Qwen2VLImageEmbeddingInputs(type="image_embeds",
+                                               data=image_embeds)
 
     def _parse_and_validate_video_input(
             self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]:
         pixel_values_videos = kwargs.pop("pixel_values_videos", None)
@@ -947,7 +992,10 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
     def _process_image_input(self,
                              image_input: Qwen2VLImageInputs) -> torch.Tensor:
-        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+        if image_input["type"] == "image_embeds":
+            return image_input["data"].type(self.visual.dtype)
+
+        pixel_values = image_input["data"].type(self.visual.dtype)
         image_embeds = self.visual(pixel_values,
                                    grid_thw=image_input["image_grid_thw"])
         return image_embeds
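Net effect on the model side: pre-computed embeddings are only cast to the vision tower's dtype and returned, while raw pixels still run through self.visual. A self-contained sketch with a stub encoder standing in for the ViT (the stub class, its dimensions, and the helper name are illustrative, not vLLM code):

import torch
import torch.nn as nn


class StubVisualEncoder(nn.Module):
    """Stand-in for Qwen2-VL's vision tower: maps patches to LM-sized embeddings."""

    def __init__(self, patch_dim: int = 1176, hidden_size: int = 1536):
        super().__init__()
        self.proj = nn.Linear(patch_dim, hidden_size)

    @property
    def dtype(self):
        return self.proj.weight.dtype

    def forward(self, pixel_values: torch.Tensor, grid_thw: torch.Tensor):
        return self.proj(pixel_values)


def process_image_input(image_input: dict, visual: StubVisualEncoder) -> torch.Tensor:
    if image_input["type"] == "image_embeds":
        # Embedding path: skip the encoder entirely, just match dtype.
        return image_input["data"].type(visual.dtype)
    # Pixel path: run the encoder on patchified pixels.
    return visual(image_input["data"].type(visual.dtype),
                  grid_thw=image_input["image_grid_thw"])


visual = StubVisualEncoder()
embeds = process_image_input(
    {"type": "image_embeds", "data": torch.randn(64, 1536)}, visual)
pixels = process_image_input(
    {"type": "pixel_values", "data": torch.randn(256, 1176),
     "image_grid_thw": torch.tensor([[1, 16, 16]])}, visual)
print(embeds.shape, pixels.shape)  # torch.Size([64, 1536]) torch.Size([256, 1536])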