[Model] support input embeddings for qwen2vl (#8856)
parent f13a07b1f8 · commit e01ab595d8
@@ -281,7 +281,7 @@ Multimodal Language Models
    -
  * - :code:`Qwen2VLForConditionalGeneration`
    - Qwen2-VL
-    - Image\ :sup:`+` / Video\ :sup:`+`
+    - Image\ :sup:`E+` / Video\ :sup:`+`
    - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
    -
  * - :code:`UltravoxModel`
@@ -57,6 +57,23 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptT
        "multi_modal_data": {"image": image_embeds},
    })

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)

+    # Inference with image embeddings as input with additional parameters
+    # Specifically, this is a trial run of Qwen2-VL with the new input format: the model needs additional parameters to compute its positional encodings.
+    image_embeds = torch.load(...)  # torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
+    image_grid_thw = torch.load(...)  # torch.Tensor of shape (1, 3)
+    mm_data['image'] = {
+        "image_embeds": image_embeds,
+        "image_grid_thw": image_grid_thw,
+    }
+    outputs = llm.generate({
+        "prompt": prompt,
+        "multi_modal_data": mm_data,
+    })
+
+    for o in outputs:
+        generated_text = o.outputs[0].text
+        print(generated_text)
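Editorial note, not part of this commit: one way to produce the `image_embeds` and `image_grid_thw` tensors loaded above is to run the HuggingFace Qwen2-VL image processor and vision tower offline and save the results. The sketch below assumes a `transformers` version with Qwen2-VL support and a hypothetical local file `example.jpg`; exact preprocessing APIs and output shapes may vary between versions.

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

model_id = "Qwen/Qwen2-VL-2B-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
hf_model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16)

# The image processor returns flattened patches plus the (grid_t, grid_h, grid_w)
# layout that Qwen2-VL needs for positional encoding.
visual_inputs = processor.image_processor(images=[Image.open("example.jpg")],
                                          return_tensors="pt")
image_grid_thw = visual_inputs["image_grid_thw"]  # shape (1, 3)

with torch.no_grad():
    # The vision tower projects patches to the language model's hidden size.
    image_embeds = hf_model.visual(
        visual_inputs["pixel_values"].to(torch.float16),
        grid_thw=image_grid_thw)

# Depending on the expected shape, a leading batch dimension may need to be
# added (e.g. image_embeds.unsqueeze(0)) before saving.
torch.save(image_embeds, "image_embeds.pt")
torch.save(image_grid_thw, "image_grid_thw.pt")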
@@ -23,8 +23,8 @@
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import lru_cache, partial
-from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict,
-                    Union)
+from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
+                    Tuple, Type, TypedDict, Union)

import torch
import torch.nn as nn
@@ -76,19 +76,31 @@ logger = init_logger(__name__)
# === Vision Inputs === #


-class Qwen2VLImageInputs(TypedDict):
-    pixel_values: torch.Tensor
+class Qwen2VLImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: torch.Tensor
    """Shape:
    `(num_patches, num_channels * patch_size * patch_size)`
    """

    image_grid_thw: torch.Tensor
    """Shape: `(num_images, 3)`

    This should be in `(grid_t, grid_h, grid_w)` format.
    """


+class Qwen2VLImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
+                           Qwen2VLImageEmbeddingInputs]


class Qwen2VLVideoInputs(TypedDict):
    pixel_values_videos: torch.Tensor
    """Shape:
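Editorial aside, not part of the commit: the two TypedDicts above form a small tagged union, with the `type` field acting as the discriminator that later code branches on. A hypothetical construction of each variant, with made-up shapes and assuming the definitions above are in scope, looks like this:

import torch

pixel_variant = Qwen2VLImagePixelInputs(
    type="pixel_values",
    data=torch.zeros(1024, 3 * 14 * 14),         # (num_patches, num_channels * patch_size * patch_size)
    image_grid_thw=torch.tensor([[1, 32, 32]]),  # one image, (grid_t, grid_h, grid_w)
)
embed_variant = Qwen2VLImageEmbeddingInputs(
    type="image_embeds",
    data=torch.zeros(1, 256, 1536),              # (num_images, image_feature_size, hidden_size)
)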
@@ -567,6 +579,11 @@ def mm_input_mapper_for_qwen2_vl(
    data_type_key: str,
) -> MultiModalInputs:
    """Input mapper for Qwen2-VL."""
+    if data_type_key == "image" and isinstance(data, dict):
+        return MultiModalInputs({
+            "image_embeds": data.get("image_embeds"),
+            "image_grid_thw": data.get("image_grid_thw"),
+        })
    model_config = ctx.model_config
    image_processor = cached_get_image_processor(
        model_config.model, trust_remote_code=model_config.trust_remote_code)
@@ -739,6 +756,48 @@ def _get_llm_num_vision_tokens(
    return llm_num_vision_tokens


+def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
+                       data_type_key: str, image_processor: Any,
+                       prompt_token_ids: List[int]) -> List[int]:
+    """
+    Expand pad tokens for multi-modal inputs (e.g., images or videos).
+
+    Args:
+        inputs (list): The multi-modal inputs (e.g., images or videos).
+        token_id (int): The token ID used to represent the multi-modal input.
+        make_batched_fn (Callable): A function to batch the inputs.
+        data_type_key (str): The type of the multi-modal input.
+        image_processor (Any): The image processor used to process the inputs.
+        prompt_token_ids (List[int]): The list of token IDs in the prompt.
+
+    Returns:
+        List[int]: The list of token IDs for the multi-modal inputs.
+    """
+    indices = [
+        idx for idx, token in enumerate(prompt_token_ids) if token == token_id
+    ]
+    inputs = make_batched_fn(inputs)
+    assert len(indices) == len(inputs)
+
+    prompt_token_ids_with_data = []
+    for cnt, data in enumerate(inputs):
+        num_tokens = _get_llm_num_vision_tokens(
+            [data] if data_type_key == "image" else data,
+            data_type_key=data_type_key,
+            image_processor=image_processor,
+        )
+        if cnt == 0:
+            end_idx = indices[cnt]
+            non_data_tokens = prompt_token_ids[:end_idx]
+        else:
+            non_data_tokens = prompt_token_ids[indices[cnt - 1] +
+                                               1:indices[cnt]]
+        prompt_token_ids_with_data.extend(non_data_tokens)
+        prompt_token_ids_with_data.extend(token_id for _ in range(num_tokens))
+    prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:])
+    return prompt_token_ids_with_data
+
+
def input_processor_for_qwen2_vl(ctx: InputContext,
                                 llm_inputs: LLMInputs) -> LLMInputs:
    multi_modal_data = llm_inputs.get("multi_modal_data", None)
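To make the docstring above concrete, here is a small self-contained toy (editorial, not from the commit) of the same expansion idea: each placeholder token in the prompt is repeated once per vision token required by the corresponding input, so the language model gets one position per vision feature. Token IDs and counts below are made up.

IMAGE_PAD = 9999                      # hypothetical placeholder token id
prompt_token_ids = [101, IMAGE_PAD, 102, IMAGE_PAD, 103]
num_vision_tokens = [3, 2]            # pretend image 0 needs 3 tokens, image 1 needs 2

expanded = []
image_idx = 0
for tok in prompt_token_ids:
    if tok == IMAGE_PAD:
        expanded.extend([IMAGE_PAD] * num_vision_tokens[image_idx])
        image_idx += 1
    else:
        expanded.append(tok)

assert expanded == [101, 9999, 9999, 9999, 102, 9999, 9999, 103]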
@@ -775,62 +834,38 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
        )["input_ids"]

    # Expand image pad tokens.
+
    if image_inputs is not None:
-        image_indices = [
-            idx for idx, token in enumerate(prompt_token_ids)
-            if token == hf_config.image_token_id
-        ]
-        image_inputs = make_batched_images(image_inputs)
-        assert len(image_indices) == len(image_inputs)
-
-        prompt_token_ids_with_image = []
-        for image_cnt, image in enumerate(image_inputs):
-            num_image_tokens = _get_llm_num_vision_tokens(
-                [image],
-                data_type_key="image",
-                image_processor=image_processor,
-            )
-            if image_cnt == 0:
-                non_image_tokens = prompt_token_ids[:image_indices[image_cnt]]
-            else:
-                non_image_tokens = prompt_token_ids[image_indices[image_cnt -
-                                                                  1] +
-                                                    1:image_indices[image_cnt]]
-            prompt_token_ids_with_image.extend(non_image_tokens)
-            prompt_token_ids_with_image.extend(
-                hf_config.image_token_id for _ in range(num_image_tokens))
-        prompt_token_ids_with_image.extend(prompt_token_ids[image_indices[-1] +
-                                                            1:])
-        prompt_token_ids = prompt_token_ids_with_image
-
-    # Expand video pad tokens.
-    if video_inputs is not None:
-        video_indices = [
-            idx for idx, token in enumerate(prompt_token_ids)
-            if token == hf_config.video_token_id
-        ]
-        video_inputs = make_batched_videos(video_inputs)
-        assert len(video_indices) == len(video_inputs)
-
-        prompt_token_ids_with_video = []
-        for video_cnt, video in enumerate(video_inputs):
-            num_video_tokens = _get_llm_num_vision_tokens(
-                video,
-                data_type_key="video",
-                image_processor=image_processor,
-            )
-            if video_cnt == 0:
-                non_video_tokens = prompt_token_ids[:video_indices[video_cnt]]
-            else:
-                non_video_tokens = prompt_token_ids[video_indices[video_cnt -
-                                                                  1] +
-                                                    1:video_indices[video_cnt]]
-            prompt_token_ids_with_video.extend(non_video_tokens)
-            prompt_token_ids_with_video.extend(
-                hf_config.video_token_id for _ in range(num_video_tokens))
-        prompt_token_ids_with_video.extend(prompt_token_ids[video_indices[-1] +
-                                                            1:])
-        prompt_token_ids = prompt_token_ids_with_video
+        if isinstance(image_inputs, dict):
+            prompt_token_ids_with_image = []
+            image_indices = [
+                idx for idx, token in enumerate(prompt_token_ids)
+                if token == hf_config.image_token_id
+            ]
+            image_cnt = len(image_indices)
+            embed_dim = image_inputs.get('image_embeds').size(0)
+            assert embed_dim % image_cnt == 0
+            num_pad_tokens = embed_dim // image_cnt
+            for idx, token in enumerate(prompt_token_ids):
+                if idx in image_indices:
+                    prompt_token_ids_with_image.extend([token] *
+                                                       num_pad_tokens)
+                else:
+                    prompt_token_ids_with_image.append(token)
+            prompt_token_ids = prompt_token_ids_with_image
+        else:
+            prompt_token_ids = _expand_pad_tokens(image_inputs,
+                                                  hf_config.image_token_id,
+                                                  make_batched_images, "image",
+                                                  image_processor,
+                                                  prompt_token_ids)

+    if video_inputs is not None:
+        prompt_token_ids = _expand_pad_tokens(video_inputs,
+                                              hf_config.video_token_id,
+                                              make_batched_videos, "video",
+                                              image_processor,
+                                              prompt_token_ids)

    return LLMInputs(
        prompt_token_ids=prompt_token_ids,
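A quick editorial note on the embeddings branch above (not from the commit): when `image_embeds` is supplied directly, the image processor never runs, so the number of placeholder repetitions is inferred from the embedding tensor itself by dividing its first dimension evenly across the image placeholders found in the prompt. A toy check with made-up sizes:

import torch

image_embeds = torch.zeros(1152, 1536)  # made-up: dim 0 = total vision features across all images
image_cnt = 2                           # image placeholder tokens found in the prompt
assert image_embeds.size(0) % image_cnt == 0
num_pad_tokens = image_embeds.size(0) // image_cnt
assert num_pad_tokens == 576            # each placeholder expands to 576 positions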
@@ -910,11 +945,13 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Qwen2VLImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

-        if pixel_values is None:
+        if pixel_values is None and image_embeds is None:
            return None

+        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
@@ -924,9 +961,17 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
                raise ValueError("Incorrect type of image pixel values. "
                                 f"Got type: {type(pixel_values)}")

-        return Qwen2VLImageInputs(pixel_values=pixel_values,
+            return Qwen2VLImagePixelInputs(type="pixel_values",
+                                           data=pixel_values,
                                           image_grid_thw=image_grid_thw)

+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+            return Qwen2VLImageEmbeddingInputs(type="image_embeds",
+                                               data=image_embeds)

    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]:
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
@@ -947,7 +992,10 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):

    def _process_image_input(self,
                             image_input: Qwen2VLImageInputs) -> torch.Tensor:
-        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+        if image_input["type"] == "image_embeds":
+            return image_input["data"].type(self.visual.dtype)
+
+        pixel_values = image_input["data"].type(self.visual.dtype)
        image_embeds = self.visual(pixel_values,
                                   grid_thw=image_input["image_grid_thw"])
        return image_embeds
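For illustration only (not part of the commit): the method above is a plain branch on the tagged union's `type` field. Precomputed embeddings skip the vision tower and are only cast to its dtype, while raw pixel values still pass through `self.visual`. A standalone sketch with a stub encoder argument standing in for the vision tower:

import torch

def process_image_input_sketch(image_input: dict, visual_encoder, dtype=torch.float16):
    # Mirrors the dispatch in Qwen2VLForConditionalGeneration._process_image_input above.
    if image_input["type"] == "image_embeds":
        # Embeddings were computed offline; only match the vision tower's dtype.
        return image_input["data"].type(dtype)
    # Raw pixels: run the vision tower to get LM-hidden-size embeddings.
    return visual_encoder(image_input["data"].type(dtype),
                          grid_thw=image_input["image_grid_thw"])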