[Model] Add Idefics3 support (#9767)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: B-201 <Joy25810@foxmail.com> Co-authored-by: B-201 <Joy25810@foxmail.com>
This commit is contained in:
parent
2003cc3513
commit
a5bba7d234
@ -446,6 +446,12 @@ Text Generation
|
||||
- :code:`h2oai/h2ovl-mississippi-800m`, :code:`h2oai/h2ovl-mississippi-2b`, etc.
|
||||
-
|
||||
- ✅︎
|
||||
* - :code:`Idefics3ForConditionalGeneration`
|
||||
- Idefics3
|
||||
- T + I
|
||||
- :code:`HuggingFaceM4/Idefics3-8B-Llama3` etc.
|
||||
-
|
||||
-
|
||||
* - :code:`InternVLChatModel`
|
||||
- InternVL2
|
||||
- T + I\ :sup:`E+`
|
||||
|
@ -377,6 +377,22 @@ def run_glm4v(question: str, modality: str):
|
||||
return llm, prompt, stop_token_ids
|
||||
|
||||
|
||||
# Idefics3-8B-Llama3
|
||||
def run_idefics3(question: str, modality: str):
|
||||
assert modality == "image"
|
||||
model_name = "HuggingFaceM4/Idefics3-8B-Llama3"
|
||||
|
||||
llm = LLM(model=model_name,
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
enforce_eager=True)
|
||||
prompt = (
|
||||
f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
|
||||
)
|
||||
stop_token_ids = None
|
||||
return llm, prompt, stop_token_ids
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"llava": run_llava,
|
||||
"llava-next": run_llava_next,
|
||||
@ -397,6 +413,7 @@ model_example_map = {
|
||||
"mllama": run_mllama,
|
||||
"molmo": run_molmo,
|
||||
"glm4v": run_glm4v,
|
||||
"idefics3": run_idefics3,
|
||||
}
|
||||
|
||||
|
||||
|
@ -290,6 +290,30 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
def load_idefics3(question, image_urls: List[str]) -> ModelRequestData:
|
||||
model_name = "HuggingFaceM4/Idefics3-8B-Llama3"
|
||||
|
||||
# The configuration below has been confirmed to launch on a single L40 GPU.
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
max_model_len=8192,
|
||||
max_num_seqs=16,
|
||||
enforce_eager=True,
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
)
|
||||
|
||||
placeholders = "\n".join(f"Image-{i}: <image>\n"
|
||||
for i, _ in enumerate(image_urls, start=1))
|
||||
prompt = f"<|begin_of_text|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
stop_token_ids=None,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
chat_template=None,
|
||||
)
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"phi3_v": load_phi3v,
|
||||
"h2ovl_chat": load_h2onvl,
|
||||
@ -298,6 +322,7 @@ model_example_map = {
|
||||
"qwen2_vl": load_qwen2_vl,
|
||||
"qwen_vl_chat": load_qwenvl_chat,
|
||||
"mllama": load_mllama,
|
||||
"idefics3": load_idefics3,
|
||||
}
|
||||
|
||||
|
||||
|
@ -327,6 +327,22 @@ VLM_TEST_SETTINGS = {
|
||||
vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output,
|
||||
prompt_path_encoder=model_utils.qwen_prompt_path_encoder,
|
||||
),
|
||||
"idefics3": VLMTestInfo(
|
||||
models=["HuggingFaceM4/Idefics3-8B-Llama3"],
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||
prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501
|
||||
img_idx_to_prompt=lambda idx: "<image>",
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForVision2Seq,
|
||||
marks=[
|
||||
pytest.mark.skipif(
|
||||
transformers.__version__ < "4.46.0",
|
||||
reason="Model introduced in HF >= 4.46.0"
|
||||
),
|
||||
large_gpu_mark(min_gb=48),
|
||||
],
|
||||
),
|
||||
### Tensor parallel / multi-gpu broadcast tests
|
||||
"broadcast-chameleon": VLMTestInfo(
|
||||
models=["facebook/chameleon-7b"],
|
||||
|
@ -187,6 +187,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
return "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
if model_type == "molmo":
|
||||
return ""
|
||||
if model_type == "idefics3":
|
||||
return "<image>"
|
||||
|
||||
raise TypeError(f"Unknown {modality} model type: {model_type}")
|
||||
elif modality == "audio":
|
||||
|
@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch Idefics2 model."""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -29,6 +29,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
class Idefics2VisionEmbeddings(nn.Module):
|
||||
@ -329,3 +330,25 @@ class Idefics2VisionTransformer(nn.Module):
|
||||
encoder_outputs = self.encoder(hidden_states)
|
||||
last_hidden_state = self.post_layernorm(encoder_outputs)
|
||||
return last_hidden_state
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
632
vllm/model_executor/models/idefics3.py
Normal file
632
vllm/model_executor/models/idefics3.py
Normal file
@ -0,0 +1,632 @@
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Idefics3 model compatible with HuggingFace weights."""
|
||||
|
||||
import math
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
# Temporary solution for transformers below 4.46.0.
|
||||
from transformers import PretrainedConfig as Idefics3Config
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
|
||||
from vllm.multimodal.image import cached_get_image_processor
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
# yapf: disable
|
||||
from .idefics2_vision_model import (
|
||||
Idefics2VisionTransformer as Idefics3VisionTransformer)
|
||||
# yapf: enable
|
||||
from .interfaces import SupportsMultiModal
|
||||
from .llama import LlamaModel
|
||||
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Idefics3ImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: torch.Tensor
|
||||
"""
|
||||
Shape: `(batch_size * num_images, num_channels, height, width)`
|
||||
"""
|
||||
rows: List[int]
|
||||
cols: List[int]
|
||||
pixel_attention_mask: Optional[torch.BoolTensor]
|
||||
|
||||
|
||||
class Idefics3ImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
data: torch.Tensor
|
||||
"""
|
||||
Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
|
||||
`hidden_size` must match the hidden size of language model backbone.
|
||||
"""
|
||||
|
||||
|
||||
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
|
||||
|
||||
|
||||
def input_mapper_for_idefics3(
|
||||
ctx: InputContext,
|
||||
data: object,
|
||||
):
|
||||
model_config = ctx.model_config
|
||||
image_processor = cached_get_image_processor(
|
||||
model_config.model, trust_remote_code=model_config.trust_remote_code)
|
||||
if image_processor is None:
|
||||
raise RuntimeError("No HuggingFace processor is available "
|
||||
"to process the image object")
|
||||
|
||||
if isinstance(data, Image.Image):
|
||||
images = [[data]]
|
||||
elif is_list_of(data, Image.Image):
|
||||
images = [data]
|
||||
else:
|
||||
raise TypeError(f"Invalid image type: {type(data)}")
|
||||
|
||||
try:
|
||||
batch_data = image_processor(images,
|
||||
return_tensors="pt",
|
||||
return_row_col_info=True).data
|
||||
except Exception:
|
||||
logger.error("Failed to process image (%s)", data)
|
||||
raise
|
||||
|
||||
return MultiModalInputs(batch_data)
|
||||
|
||||
|
||||
def _resize_output_size(height: int,
|
||||
width: int,
|
||||
max_len: Optional[int] = None,
|
||||
min_len: Optional[int] = 1,
|
||||
max_size: Optional[int] = None) -> Tuple[int, int]:
|
||||
# Set default value for max_len if not provided
|
||||
max_len = max(height, width) if max_len is None else max_len
|
||||
aspect_ratio = width / height
|
||||
|
||||
# Handle the maximum size constraint
|
||||
if max_size is not None:
|
||||
max_len = min(max_len, max_size)
|
||||
|
||||
# Adjust dimensions according to the aspect ratio
|
||||
if width >= height:
|
||||
width = max_len
|
||||
height = int(width / aspect_ratio)
|
||||
else:
|
||||
height = max_len
|
||||
width = int(height * aspect_ratio)
|
||||
|
||||
# Ensure both width and height are even (if needed)
|
||||
height += 1 if height % 2 != 0 else 0
|
||||
width += 1 if width % 2 != 0 else 0
|
||||
|
||||
# Ensure dimensions are not smaller than the minimum length
|
||||
height = max(height, min_len)
|
||||
width = max(width, min_len)
|
||||
|
||||
return height, width
|
||||
|
||||
|
||||
def _get_resize_output_image_size(
|
||||
image_size: Tuple[int, int],
|
||||
resolution_max_side: int,
|
||||
max_image_size: int = 1820,
|
||||
) -> Tuple[int, int]:
|
||||
if resolution_max_side > max_image_size:
|
||||
raise ValueError(
|
||||
"`resolution_max_side` cannot be larger than `max_image_size`")
|
||||
|
||||
height, width = image_size
|
||||
|
||||
# Find the output size, when rescaling the longest edge to max_len and
|
||||
# preserving the aspect ratio
|
||||
height, width = _resize_output_size(height,
|
||||
width,
|
||||
max_len=resolution_max_side)
|
||||
|
||||
return height, width
|
||||
|
||||
|
||||
def _prompt_split_image(image_seq_len: int, image_rows: int, image_cols: int,
|
||||
fake_token_around_image: str, image_token: str,
|
||||
global_img_token: str) -> str:
|
||||
"""
|
||||
Prompt with expanded image tokens for when the image is split
|
||||
into patches.
|
||||
"""
|
||||
text_split_images = ""
|
||||
for n_h in range(image_rows):
|
||||
for n_w in range(image_cols):
|
||||
text_split_images += (fake_token_around_image +
|
||||
f"<row_{n_h + 1}_col_{n_w + 1}>" +
|
||||
image_token * image_seq_len)
|
||||
text_split_images += "\n"
|
||||
|
||||
text_split_images += "\n" + _prompt_single_image(
|
||||
image_seq_len=image_seq_len,
|
||||
fake_token_around_image=fake_token_around_image,
|
||||
image_token=image_token,
|
||||
global_img_token=global_img_token)
|
||||
return text_split_images
|
||||
|
||||
|
||||
def _prompt_single_image(image_seq_len: int, fake_token_around_image: str,
|
||||
image_token: str, global_img_token: str):
|
||||
"""Prompt with expanded image tokens for a single image."""
|
||||
return (fake_token_around_image + global_img_token +
|
||||
image_token * image_seq_len + fake_token_around_image)
|
||||
|
||||
|
||||
def _get_image_prompt_string(image_rows: int, image_cols: int,
|
||||
image_seq_len: int, fake_token_around_image: str,
|
||||
image_token: str, global_img_token: str):
|
||||
if image_rows == 0 and image_cols == 0:
|
||||
return _prompt_single_image(
|
||||
image_seq_len=image_seq_len,
|
||||
fake_token_around_image=fake_token_around_image,
|
||||
image_token=image_token,
|
||||
global_img_token=global_img_token,
|
||||
)
|
||||
return _prompt_split_image(image_seq_len, image_rows, image_cols,
|
||||
fake_token_around_image, image_token,
|
||||
global_img_token)
|
||||
|
||||
|
||||
def input_processor_for_idefics3(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return inputs
|
||||
|
||||
model_config = ctx.model_config
|
||||
processor = cached_get_processor(model_config.model)
|
||||
image_processor = processor.image_processor
|
||||
tokenizer = processor.tokenizer
|
||||
size = image_processor.size['longest_edge']
|
||||
max_image_size = image_processor.max_image_size['longest_edge']
|
||||
|
||||
image_data = multi_modal_data["image"]
|
||||
if isinstance(image_data, Image.Image):
|
||||
image_list = [image_data]
|
||||
elif is_list_of(image_data, Image.Image):
|
||||
image_list = image_data
|
||||
else:
|
||||
raise TypeError(f"Invalid image type: {type(image_data)}")
|
||||
|
||||
image_rows = []
|
||||
image_cols = []
|
||||
for image in image_list:
|
||||
height, width = _get_resize_output_image_size(image.size, size)
|
||||
|
||||
rows = math.ceil(height / max_image_size)
|
||||
cols = math.ceil(width / max_image_size)
|
||||
image_rows.append(rows)
|
||||
image_cols.append(cols)
|
||||
image_rows = [image_rows]
|
||||
image_cols = [image_cols]
|
||||
|
||||
n_images_in_text = []
|
||||
|
||||
text = inputs.get("prompt")
|
||||
if text is not None:
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, "
|
||||
"or a list of strings")
|
||||
|
||||
fake_image_token = processor.fake_image_token.content
|
||||
image_token = processor.image_token.content
|
||||
global_img_token = processor.global_image_tag
|
||||
|
||||
prompt_strings = []
|
||||
for sample, sample_rows, sample_cols in zip(text, image_rows,
|
||||
image_cols):
|
||||
n_images_in_text.append(sample.count(image_token))
|
||||
|
||||
# Replace the image token with fake tokens around the expanded
|
||||
# image token sequence of length `image_seq_len`
|
||||
image_prompt_strings = []
|
||||
for n_rows, n_cols in zip(sample_rows, sample_cols):
|
||||
image_prompt_string = _get_image_prompt_string(
|
||||
n_rows,
|
||||
n_cols,
|
||||
processor.image_seq_len,
|
||||
image_token=image_token,
|
||||
fake_token_around_image=fake_image_token,
|
||||
global_img_token=global_img_token,
|
||||
)
|
||||
image_prompt_strings.append(image_prompt_string)
|
||||
|
||||
split_sample = sample.split(image_token)
|
||||
if len(split_sample) == 0:
|
||||
raise ValueError(
|
||||
"The image token should be present in the text.")
|
||||
|
||||
# Place in the image prompt strings where the image tokens are
|
||||
sample = split_sample[0]
|
||||
for i, image_prompt_string in enumerate(image_prompt_strings):
|
||||
sample += image_prompt_string + split_sample[i + 1]
|
||||
prompt_strings.append(sample)
|
||||
|
||||
prompt_token_ids = tokenizer(text=prompt_strings[0]).input_ids
|
||||
|
||||
return token_inputs(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt=prompt_strings[0],
|
||||
multi_modal_data=multi_modal_data,
|
||||
)
|
||||
|
||||
|
||||
def get_max_idefics3_image_tokens(ctx: InputContext,
|
||||
*,
|
||||
num_crops: Optional[int] = None):
|
||||
model_config = ctx.model_config
|
||||
processor = cached_get_processor(model_config.model)
|
||||
image_seq_len = processor.image_seq_len
|
||||
image_processor = processor.image_processor
|
||||
|
||||
size = image_processor.size['longest_edge']
|
||||
max_image_size = image_processor.max_image_size['longest_edge']
|
||||
resized_height, resized_width = size, size
|
||||
|
||||
grid_h = resized_height // max_image_size
|
||||
grid_w = resized_width // max_image_size
|
||||
|
||||
return (grid_h * grid_w + 1) * image_seq_len
|
||||
|
||||
|
||||
def dummy_data_for_idefics3(ctx: InputContext, seq_len: int,
|
||||
mm_counts: Mapping[str, int]) -> DummyData:
|
||||
hf_config = ctx.get_hf_config()
|
||||
num_images = mm_counts["image"]
|
||||
|
||||
processor = cached_get_processor(ctx.model_config.model)
|
||||
image_seq_len = processor.image_seq_len
|
||||
max_llm_image_tokens = 17 * image_seq_len * num_images
|
||||
|
||||
seq_data = SequenceData.from_prompt_token_counts(
|
||||
(hf_config.image_token_id, max_llm_image_tokens), (0, seq_len))
|
||||
|
||||
width = height = hf_config.vision_config.image_size
|
||||
image = Image.new("RGB", (width, height), color=0)
|
||||
mm_data = {"image": [image] if num_images == 1 else [image] * num_images}
|
||||
|
||||
return DummyData(seq_data, mm_data)
|
||||
|
||||
|
||||
class Idefics3SimpleMLP(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
input_size = config.vision_config.hidden_size * (config.scale_factor**
|
||||
2)
|
||||
output_size = config.text_config.hidden_size
|
||||
self.proj = ReplicatedLinear(input_size, output_size, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
out, _ = self.proj(x)
|
||||
return out
|
||||
|
||||
|
||||
class Idefics3Connector(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.scale_factor = config.scale_factor
|
||||
self.modality_projection = Idefics3SimpleMLP(config)
|
||||
|
||||
def pixel_shuffle(self,
|
||||
x: torch.Tensor,
|
||||
scale_factor: int = 2) -> torch.Tensor:
|
||||
bsz, seq, embed_dim = x.size()
|
||||
height = width = int(seq**0.5)
|
||||
x = x.view(bsz, height, width, embed_dim)
|
||||
x = x.view(bsz, height, int(width / scale_factor),
|
||||
embed_dim * scale_factor)
|
||||
x = x.permute(0, 2, 1, 3)
|
||||
x = x.reshape(
|
||||
bsz,
|
||||
int(width / scale_factor),
|
||||
int(height / scale_factor),
|
||||
embed_dim * (scale_factor**2),
|
||||
)
|
||||
x = x.permute(0, 2, 1, 3)
|
||||
x = x.reshape(bsz, int(seq / (scale_factor**2)),
|
||||
embed_dim * (scale_factor**2))
|
||||
return x
|
||||
|
||||
def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
image_hidden_states = self.pixel_shuffle(image_hidden_states,
|
||||
self.scale_factor)
|
||||
image_hidden_states = self.modality_projection(image_hidden_states)
|
||||
return image_hidden_states
|
||||
|
||||
|
||||
class Idefics3Model(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Idefics3Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = self.config.text_config.pad_token_id
|
||||
self.vocab_size = self.config.text_config.vocab_size
|
||||
|
||||
self.vision_model = Idefics3VisionTransformer(config.vision_config,
|
||||
quant_config)
|
||||
self.connector = Idefics3Connector(config)
|
||||
self.text_model = LlamaModel(config.text_config, cache_config,
|
||||
quant_config)
|
||||
|
||||
self.image_seq_len = int(
|
||||
((config.vision_config.image_size //
|
||||
config.vision_config.patch_size)**2) / (config.scale_factor**2))
|
||||
self.image_token_id = self.config.image_token_id
|
||||
|
||||
def _validate_pixel_values(
|
||||
self, data: Union[torch.Tensor, List[torch.Tensor]]
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
|
||||
h = w = self.config.vision_config.image_size
|
||||
expected_dims = (3, h, w)
|
||||
|
||||
def _validate_shape(d: torch.Tensor):
|
||||
actual_dims = tuple(d.shape[1:])
|
||||
|
||||
if actual_dims != expected_dims:
|
||||
expected_expr = ("num_patches", *map(str, expected_dims))
|
||||
raise ValueError(
|
||||
"The expected shape of pixel values per image per batch "
|
||||
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
|
||||
|
||||
for d in data:
|
||||
_validate_shape(d)
|
||||
|
||||
return data
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[ImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
rows = kwargs.pop("rows", None)
|
||||
cols = kwargs.pop("cols", None)
|
||||
pixel_attention_mask = kwargs.pop("pixel_attention_mask", None)
|
||||
|
||||
if pixel_values is None and image_embeds is None:
|
||||
return None
|
||||
|
||||
if image_embeds is not None:
|
||||
if not isinstance(image_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
return Idefics3ImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
data=flatten_bn(image_embeds, concat=True),
|
||||
)
|
||||
|
||||
if pixel_values is not None:
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
return Idefics3ImagePixelInputs(type="pixel_values",
|
||||
data=self._validate_pixel_values(
|
||||
flatten_bn(pixel_values,
|
||||
concat=True)),
|
||||
rows=rows,
|
||||
cols=cols,
|
||||
pixel_attention_mask=flatten_bn(
|
||||
pixel_attention_mask,
|
||||
concat=True))
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _image_pixels_to_features(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# NOTE: we skip the step to select the vision feature layer since
|
||||
# this is already done inside the vision tower
|
||||
batch_size, num_images, num_channels, height, width = pixel_values.shape
|
||||
pixel_values = pixel_values.to(
|
||||
dtype=self.vision_model.embeddings.patch_embedding.weight.dtype
|
||||
) # fp16 compatibility
|
||||
pixel_values = pixel_values.view(batch_size * num_images,
|
||||
*pixel_values.shape[2:])
|
||||
|
||||
# Remove padding images - padding images are full 0.
|
||||
nb_values_per_image = pixel_values.shape[1:].numel()
|
||||
real_images_inds = (pixel_values == 0.0).sum(
|
||||
dim=(-1, -2, -3)) != nb_values_per_image
|
||||
pixel_values = pixel_values[real_images_inds].contiguous()
|
||||
|
||||
# Handle the vision attention mask
|
||||
if pixel_attention_mask is None:
|
||||
pixel_attention_mask = torch.ones(
|
||||
size=(pixel_values.size(0), pixel_values.size(2),
|
||||
pixel_values.size(3)),
|
||||
dtype=torch.bool,
|
||||
device=pixel_values.device,
|
||||
)
|
||||
else:
|
||||
# Remove padding images from the mask
|
||||
pixel_attention_mask = pixel_attention_mask.view(
|
||||
batch_size * num_images, *pixel_attention_mask.shape[2:])
|
||||
pixel_attention_mask = pixel_attention_mask[
|
||||
real_images_inds].contiguous()
|
||||
|
||||
patch_size = self.config.vision_config.patch_size
|
||||
patches_subgrid = pixel_attention_mask.unfold(dimension=1,
|
||||
size=patch_size,
|
||||
step=patch_size)
|
||||
patches_subgrid = patches_subgrid.unfold(dimension=2,
|
||||
size=patch_size,
|
||||
step=patch_size)
|
||||
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
||||
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
)
|
||||
|
||||
return image_hidden_states
|
||||
|
||||
def _process_image_pixels(
|
||||
self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor:
|
||||
assert self.vision_model is not None
|
||||
|
||||
pixel_values = inputs["data"]
|
||||
pixel_attention_mask = inputs["pixel_attention_mask"]
|
||||
|
||||
return self._image_pixels_to_features(pixel_values,
|
||||
pixel_attention_mask)
|
||||
|
||||
def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor:
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["data"]
|
||||
|
||||
assert self.vision_model is not None
|
||||
image_features = self._process_image_pixels(image_input)
|
||||
return self.connector(image_features)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs: object,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if intermediate_tensors is not None:
|
||||
input_ids = None
|
||||
inputs_embeds = None
|
||||
else:
|
||||
# always pass the input via `inputs_embeds`
|
||||
# to make sure the computation graph is consistent
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
|
||||
if image_input is not None:
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
inputs_embeds = self.text_model.get_input_embeddings(input_ids)
|
||||
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, vision_embeddings,
|
||||
self.image_token_id)
|
||||
else:
|
||||
inputs_embeds = self.text_model.get_input_embeddings(input_ids)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.text_model(
|
||||
input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_idefics3)
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_idefics3_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_idefics3)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3)
|
||||
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Idefics3Config,
|
||||
multimodal_config: MultiModalConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.model = Idefics3Model(config, cache_config, quant_config)
|
||||
self.image_token_id = self.config.image_token_id
|
||||
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.text_config.vocab_size,
|
||||
config.text_config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if self.config.text_config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.text_model.wte.weight
|
||||
self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs: object,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
**kwargs,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
loader.load_weights(weights)
|
@ -120,6 +120,7 @@ _MULTIMODAL_MODELS = {
|
||||
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
|
||||
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
|
||||
"InternVLChatModel": ("internvl", "InternVLChatModel"),
|
||||
"Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
|
||||
"LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
|
||||
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
|
||||
"LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501
|
||||
|
Loading…
x
Reference in New Issue
Block a user