[VLM] Various cleanup and fixes (#14806)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

Parent: 40253bab44
Commit: ab93f1360f
@@ -37,6 +37,7 @@ from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import MediaConnector
+from vllm.transformers_utils.processor import cached_get_processor
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 
 logger = init_logger(__name__)
@@ -1070,7 +1071,19 @@ def apply_hf_chat_template(
     tokenize: bool = False,  # Different from HF's default
     **kwargs: Any,
 ) -> str:
-    if chat_template is None and tokenizer.chat_template is None:
+    if chat_template is None:
+        chat_template = tokenizer.chat_template
+
+    # FIXME: Temporary workaround for
+    # https://huggingface.co/mistral-community/pixtral-12b/discussions/31
+    if chat_template is None:
+        try:
+            processor = cached_get_processor(tokenizer.name_or_path)
+            chat_template = processor.chat_template
+        except Exception:
+            pass
+
+    if chat_template is None:
         raise ValueError(
             "As of transformers v4.44, default chat template is no longer "
             "allowed, so you must provide a chat template if the tokenizer "
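Note: the reworked `apply_hf_chat_template` resolves the template in three steps: an explicitly passed template, then the tokenizer's `chat_template`, then (as a temporary workaround) the HF processor's `chat_template`. A standalone sketch of that resolution order, assuming only that the tokenizer exposes `chat_template` and `name_or_path`; `resolve_chat_template` and its `get_processor` parameter are illustrative names, not vLLM APIs:

from typing import Callable, Optional

def resolve_chat_template(
    explicit_template: Optional[str],
    tokenizer,                # assumed to expose .chat_template / .name_or_path
    get_processor: Callable,  # stand-in for cached_get_processor
) -> str:
    # 1. A caller-provided template always wins.
    chat_template = explicit_template

    # 2. Fall back to the template bundled with the tokenizer.
    if chat_template is None:
        chat_template = tokenizer.chat_template

    # 3. Some repos (e.g. pixtral-12b) only ship the template on the processor.
    if chat_template is None:
        try:
            processor = get_processor(tokenizer.name_or_path)
            chat_template = processor.chat_template
        except Exception:
            pass

    if chat_template is None:
        raise ValueError("No chat template found; please provide one explicitly.")
    return chat_template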
@@ -18,7 +18,7 @@
 """ PyTorch Fuyu model."""
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import List, Literal, Optional, Set, Tuple, TypedDict
+from typing import Literal, Optional, Set, Tuple, TypedDict
 
 import torch
 import torch.nn as nn
@@ -31,8 +31,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.models.persimmon import PersimmonForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
-                                    NestedTensors)
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -58,10 +57,12 @@ class FuyuImagePatchInputs(TypedDict):
     `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
     """
 
-    patches_per_image: List[int]
+    patches_per_image: list[int]
     """
-    List of number of total patches for each image in the batch.
-    This is used to restore the first two dimensions of `flat_data`.
+    The number of total patches for each image in the batch.
+
+    This is used to split the embeddings which has the first two dimensions
+    flattened just like `flat_data`.
     """
 
 
@@ -317,7 +318,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         return None
 
     def _process_image_input(
-            self, image_input: FuyuImagePatchInputs) -> NestedTensors:
+            self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings:
         image_patches_flat = image_input["flat_data"]
         patches_per_image = image_input["patches_per_image"]
 
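Note: per the updated `FuyuImagePatchInputs` docstring, `patches_per_image` exists so that embeddings computed from the flattened patch batch can be split back into one tensor per image. A minimal sketch of that splitting step; the sizes and `hidden_size` below are made-up example values:

import torch

patches_per_image = [6, 4]          # two images contributed 6 and 4 patches
hidden_size = 8
flat_embeddings = torch.randn(sum(patches_per_image), hidden_size)

# Recover one embedding tensor per image from the flattened layout.
per_image_embeddings = flat_embeddings.split(patches_per_image, dim=0)
assert [e.shape[0] for e in per_image_embeddings] == patches_per_image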
@@ -5,7 +5,7 @@ from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
 
 import torch
 from torch import Tensor
-from typing_extensions import TypeIs
+from typing_extensions import Self, TypeIs
 
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
@@ -451,7 +451,7 @@ class SupportsQuant:
     packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {}
     quant_config: Optional[QuantizationConfig] = None
 
-    def __new__(cls, *args, **kwargs) -> "SupportsQuant":
+    def __new__(cls, *args, **kwargs) -> Self:
         instance = super().__new__(cls)
         quant_config = cls._find_quant_config(*args, **kwargs)
         if quant_config is not None:
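Note: returning `Self` from `__new__` instead of the hard-coded `"SupportsQuant"` string lets type checkers infer the concrete subclass type for classes that inherit the constructor. A toy illustration, unrelated to vLLM's actual class hierarchy:

from typing_extensions import Self

class Base:
    def __new__(cls, *args, **kwargs) -> Self:
        # Checkers now see Derived() as Derived rather than Base.
        return super().__new__(cls)

class Derived(Base):
    pass

obj = Derived()
assert type(obj) is Derived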
@@ -3,8 +3,8 @@
 from abc import abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
-from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
-                    TypedDict, TypeVar, Union, cast)
+from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
+                    TypeVar, Union, cast)
 
 import torch
 import torch.nn as nn
@@ -39,8 +39,7 @@ from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .pixtral import (PixtralHFVisionModel,
-                      get_pixtral_hf_image_feature_grid_size)
+from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -49,7 +48,7 @@ from .vision import get_vision_encoder_info
 
 class LlavaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
+    pixel_values: torch.Tensor
     """
     Shape: `(batch_size * num_images, num_channels, height, width)`
 
@@ -57,7 +56,18 @@ class LlavaImagePixelInputs(TypedDict):
     in which case the data is passed as a list instead of a batched tensor.
     """
 
-    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+
+class PixtralHFImagePixelInputs(TypedDict):
+    type: Literal["pixel_values_pixtral"]
+    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    Shape: `(batch_size * num_images, num_channels, height, width)`
+
+    Note that `height` or `width` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
+    """
+
+    feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image features correspond
     to patch tokens.
@@ -65,7 +75,7 @@ class LlavaImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_crops, num_patch)`
     """
 
-    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image embeddings correspond
    to patch tokens.
@@ -73,7 +83,7 @@ class LlavaImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_embeds)`
     """
 
-    num_crops: torch.Tensor
+    num_crops: Union[torch.Tensor, list[torch.Tensor]]
     """Shape: `(batch_size, num_images)`"""
 
 
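Note: giving the Pixtral path its own `PixtralHFImagePixelInputs` with a distinct `Literal` tag means downstream code can discriminate the union purely on the `type` field, as the later `!= "pixel_values_pixtral"` check does. A toy sketch of that pattern with trimmed-down fields:

from typing import Literal, TypedDict, Union

import torch

class LlavaPixelInputs(TypedDict):
    type: Literal["pixel_values"]
    pixel_values: torch.Tensor

class PixtralPixelInputs(TypedDict):
    type: Literal["pixel_values_pixtral"]
    pixel_values: torch.Tensor

ImageInputs = Union[LlavaPixelInputs, PixtralPixelInputs]

def describe(inputs: ImageInputs) -> str:
    # The tag alone decides which processing path to take.
    if inputs["type"] == "pixel_values_pixtral":
        return "pixtral path"
    return "default llava path"

print(describe({"type": "pixel_values", "pixel_values": torch.zeros(1)}))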
@@ -85,27 +95,9 @@ class LlavaImageEmbeddingInputs(TypedDict):
     `hidden_size` must match the hidden size of language model backbone.
     """
 
-    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
-    """
-    A boolean mask indicating which image features correspond
-    to patch tokens.
-
-    Shape: `(batch_size, num_crops, num_patch)`
-    """
-
-    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
-    """
-    A boolean mask indicating which image embeddings correspond
-    to patch tokens.
-
-    Shape: `(batch_size, num_embeds)`
-    """
-
-    num_crops: torch.Tensor
-    """Shape: `(batch_size, num_images)`"""
-
 
-LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
+LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs,
+                         LlavaImageEmbeddingInputs]
 
 
 class LlavaMultiModalProjector(nn.Module):
@@ -357,13 +349,15 @@ class PixtralHFMultiModalProcessor(
         ]
 
         hf_config = self.info.get_hf_config()
+        vision_config = hf_config.vision_config
+        assert isinstance(vision_config, PixtralVisionConfig)
+        encoder_info = PixtralHFEncoderInfo(vision_config)
 
         tile_sizes = [
-            get_pixtral_hf_image_feature_grid_size(
-                hf_config.vision_config,
+            encoder_info.get_patch_grid_size(
                 image_width=pixel_value.shape[-1],
-                image_height=pixel_value.shape[-2])
-            for pixel_value in processed_outputs["pixel_values"]
+                image_height=pixel_value.shape[-2],
+            ) for pixel_value in processed_outputs["pixel_values"]
         ]
         num_crops = torch.tensor([(ncols + 1) * nrows
                                   for ncols, nrows in tile_sizes])
@@ -411,13 +405,13 @@ class PixtralHFMultiModalProcessor(
 
         vision_config = hf_config.vision_config
         assert isinstance(vision_config, PixtralVisionConfig)
+        encoder_info = PixtralHFEncoderInfo(vision_config)
 
         def get_replacement(item_idx: int):
             images = mm_items.get_items("image", ImageProcessorItems)
             image_size = images.get_image_size(item_idx)
 
-            ncols, nrows = get_pixtral_hf_image_feature_grid_size(
-                vision_config,
+            ncols, nrows = encoder_info.get_patch_grid_size(
                 image_width=image_size.width,
                 image_height=image_size.height,
             )
@@ -512,7 +506,7 @@ def init_vision_tower_for_llava(
     *,
     require_post_norm: Optional[bool] = None,
    prefix: str = "",
-):
+) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
     vision_config = hf_config.vision_config
 
     # Initialize the vision tower only up to the deepest required feature layer
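Note: the processor now derives each image's `(ncols, nrows)` patch grid from `PixtralHFEncoderInfo.get_patch_grid_size` and counts `(ncols + 1) * nrows` crops, the extra column standing in for the image-break token at the end of each row. A simplified standalone sketch; the plain ceil-division grid below approximates the real resize logic shown later in the diff, and the config values are assumed:

import math

import torch

def approx_patch_grid_size(width: int, height: int,
                           image_size: int = 1024,
                           patch_size: int = 16) -> tuple[int, int]:
    # Shrink so the longer side fits image_size, then count patch tiles.
    ratio = max(width / image_size, height / image_size)
    if ratio > 1:
        width = math.ceil(width / ratio)
        height = math.ceil(height / ratio)
    return math.ceil(width / patch_size), math.ceil(height / patch_size)

pixel_value_shapes = [(3, 512, 768), (3, 1024, 1024)]  # (C, H, W) per image
tile_sizes = [
    approx_patch_grid_size(width=w, height=h) for _, h, w in pixel_value_shapes
]
num_crops = torch.tensor([(ncols + 1) * nrows for ncols, nrows in tile_sizes])
print(num_crops)  # tensor([1568, 4160])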
@@ -627,32 +621,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         if pixel_values is None and image_embeds is None:
             return None
 
-        feat_is_patch = kwargs.pop("feat_is_patch", None)
-        if feat_is_patch is not None and not isinstance(
-                feat_is_patch, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of feat_is_patch. "
-                             f"Got type: {type(feat_is_patch)}")
-
-        embed_is_patch = kwargs.pop("embed_is_patch", None)
-        if embed_is_patch is not None and not isinstance(
-                embed_is_patch, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of embed_is_patch. "
-                             f"Got type: {type(embed_is_patch)}")
-
-        num_crops = kwargs.pop("num_crops", None)
-        if num_crops is not None and not isinstance(num_crops, torch.Tensor):
-            raise ValueError("Incorrect type of num_crops. "
-                             f"Got type: {type(num_crops)}")
-
         if pixel_values is not None:
             if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
             if self.config.vision_config.model_type == "pixtral":
-                return LlavaImagePixelInputs(
-                    type="pixel_values",
-                    data=flatten_bn(pixel_values),
+                feat_is_patch = kwargs.pop("feat_is_patch")
+                if not isinstance(feat_is_patch, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of feat_is_patch. "
+                                     f"Got type: {type(feat_is_patch)}")
+
+                embed_is_patch = kwargs.pop("embed_is_patch")
+                if not isinstance(embed_is_patch, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of embed_is_patch. "
+                                     f"Got type: {type(embed_is_patch)}")
+
+                num_crops = kwargs.pop("num_crops")
+                if not isinstance(num_crops, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of num_crops. "
+                                     f"Got type: {type(num_crops)}")
+
+                return PixtralHFImagePixelInputs(
+                    type="pixel_values_pixtral",
+                    pixel_values=flatten_bn(pixel_values),
                     feat_is_patch=feat_is_patch,
                     embed_is_patch=embed_is_patch,
                     num_crops=num_crops,
@@ -660,11 +652,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 
             return LlavaImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(
+                pixel_values=self._validate_pixel_values(
                     flatten_bn(pixel_values, concat=True)),
-                feat_is_patch=feat_is_patch,
-                embed_is_patch=embed_is_patch,
-                num_crops=num_crops,
             )
 
         if image_embeds is not None:
@@ -672,12 +661,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
 
+            if self.config.vision_config.model_type == "pixtral":
+                raise ValueError("Pixtral-HF does not support image_embeds.")
+
             return LlavaImageEmbeddingInputs(
                 type="image_embeds",
                 data=flatten_bn(image_embeds, concat=True),
-                feat_is_patch=feat_is_patch,
-                embed_is_patch=embed_is_patch,
-                num_crops=num_crops,
             )
 
         raise AssertionError("This line should be unreachable.")
@@ -696,7 +685,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                             PixtralHFVisionModel],
-        pixel_values: torch.Tensor,
+        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
     ) -> torch.Tensor:
 
         # NOTE: we skip the step to select the vision feature layer since
@@ -708,17 +697,20 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             strategy=self.config.vision_feature_select_strategy,
         )
 
-    def _process_image_pixels(self,
-                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
+    def _process_image_pixels(
+        self,
+        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
+    ) -> torch.Tensor:
         assert self.vision_tower is not None
 
-        pixel_values = inputs["data"]
+        pixel_values = inputs["pixel_values"]
 
         return self._image_pixels_to_features(self.vision_tower, pixel_values)
 
-    def _process_image_input(self,
-                             image_input: LlavaImageInputs) -> torch.Tensor:
-
+    def _process_image_input(
+        self,
+        image_input: LlavaImageInputs,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         if image_input["type"] == "image_embeds":
             return image_input["data"]
 
@@ -783,11 +775,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
 
         vision_embeddings = self._process_image_input(image_input)
 
-        if kwargs.get("v0_path", False) or \
-            image_input.get("feat_is_patch") is None or \
-            image_input.get("embed_is_patch") is None:
+        if (kwargs.get("v0_path", False)
+                or image_input["type"] != "pixel_values_pixtral"):
+            # The path is used for pixtral (V0 only) and llava (V0/V1)
             return vision_embeddings
 
@@ -32,7 +32,7 @@ from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
 
 class LlavaNextImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
+    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
     """
     Shape:
     `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
@@ -315,7 +315,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         return LlavaNextImagePixelInputs(
             type="pixel_values",
-            data=self._validate_pixel_values(flatten_bn(pixel_values)),
+            pixel_values=self._validate_pixel_values(
+                flatten_bn(pixel_values)),
             image_sizes=self._validate_image_sizes(
                 flatten_bn(image_sizes, concat=True)),
         )
@@ -434,7 +435,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         assert self.vision_tower is not None
 
-        pixel_values = inputs["data"]
+        pixel_values = inputs["pixel_values"]
 
         if isinstance(pixel_values, torch.Tensor):
             b, num_patches, c, h, w = pixel_values.shape
|
||||
|
||||
class LlavaOnevisionVideoPixelInputs(TypedDict):
|
||||
type: Literal["pixel_values_videos"]
|
||||
data: Union[torch.Tensor, List[torch.Tensor]]
|
||||
pixel_values_videos: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
Shape: `(batch_size, num_videos, num_frames, num_channels, height, width)`
|
||||
|
||||
@ -54,7 +54,7 @@ class LlavaOnevisionVideoPixelInputs(TypedDict):
|
||||
|
||||
class LlavaOnevisionImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: Union[torch.Tensor, List[torch.Tensor]]
|
||||
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
|
||||
@ -521,7 +521,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
return LlavaOnevisionImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_image_pixel_values(
|
||||
pixel_values=self._validate_image_pixel_values(
|
||||
flatten_bn(pixel_values)),
|
||||
image_sizes=self._validate_image_sizes(
|
||||
flatten_bn(image_sizes, concat=True)),
|
||||
@ -570,21 +570,20 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
List[b, Tensor(nb_frames, nb_channels, height, width)]
|
||||
}
|
||||
"""
|
||||
pixel_values = kwargs.pop("pixel_values_videos", None)
|
||||
|
||||
if pixel_values is None:
|
||||
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
|
||||
if pixel_values_videos is None:
|
||||
return None
|
||||
|
||||
if not (is_list_of(pixel_values,
|
||||
(torch.Tensor)) # different shape videos
|
||||
or isinstance(pixel_values,
|
||||
if not (is_list_of(pixel_values_videos,
|
||||
torch.Tensor) # different shape videos
|
||||
or isinstance(pixel_values_videos,
|
||||
torch.Tensor)): # same shape videos
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
raise ValueError("Incorrect type of pixel_values_videos. "
|
||||
f"Got type: {type(pixel_values_videos)}")
|
||||
|
||||
return LlavaOnevisionVideoPixelInputs(
|
||||
type="pixel_values_videos",
|
||||
data=pixel_values,
|
||||
pixel_values_videos=pixel_values_videos,
|
||||
)
|
||||
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
@ -723,7 +722,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
assert self.vision_tower is not None
|
||||
|
||||
pixel_values = inputs["data"]
|
||||
pixel_values = inputs["pixel_values"]
|
||||
|
||||
if isinstance(pixel_values, torch.Tensor):
|
||||
b, num_patches, c, h, w = pixel_values.shape
|
||||
@ -757,7 +756,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
image_sizes = image_input.get("image_sizes")
|
||||
if image_sizes is None:
|
||||
batch_size = len(image_input["data"])
|
||||
batch_size = len(image_input["pixel_values"])
|
||||
vision_config = self.config.vision_config
|
||||
default_height = default_width = vision_config.image_size
|
||||
image_sizes = torch.as_tensor([[default_height, default_width]
|
||||
@ -808,7 +807,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs):
|
||||
assert self.vision_tower is not None
|
||||
|
||||
video_pixels = inputs["data"]
|
||||
video_pixels = inputs["pixel_values_videos"]
|
||||
|
||||
if isinstance(video_pixels, torch.Tensor):
|
||||
b, num_videos, frames, c, h, w = video_pixels.shape
|
||||
|
@@ -23,7 +23,6 @@
 # limitations under the License.
 """Inference-only MiniCPM-O model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping, Sequence
-from functools import partial
 from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
                     TypedDict, Union)
 
@@ -36,11 +35,12 @@ from transformers.models.whisper.modeling_whisper import (
 
 from vllm.config import VllmConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
-from vllm.multimodal.inputs import MultiModalFieldConfig
-from vllm.multimodal.parse import (AudioItem, DictEmbeddingItems, ModalityData,
+from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
+from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
+                                   DictEmbeddingItems, ModalityData,
                                    ModalityDataItems, MultiModalDataItems,
                                    MultiModalDataParser)
-from vllm.multimodal.processing import PromptReplacement
+from vllm.multimodal.processing import PromptReplacement, PromptUpdate
 from vllm.multimodal.profiling import ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
@@ -272,8 +272,13 @@ class MiniCPMOMultiModalProcessor(
                                               tokenizer.audio_end_id)
         return special_tokens
 
-    def process_audios(self, mm_data: Mapping[str, object],
-                       mm_kwargs: Mapping[str, object]) -> Dict[str, object]:
+    def process_audios(
+        self,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, NestedTensors]:
+        mm_data = dict(mm_data)
+
         audios = mm_data.pop("audios", [])
         audio_embeds = mm_data.pop("audio_embeds", [])
         if isinstance(audios, (list, torch.Tensor)) and len(audios) > 0:
@@ -332,11 +337,15 @@ class MiniCPMOMultiModalProcessor(
     def get_placeholder_split_pattern(self) -> str:
         return r"\(<(?:image|video|audio)>./</(?:image|video|audio)>\)"
 
-    def process_mm_inputs(self, mm_data, mm_kwargs) -> object:
+    def process_mm_inputs(
+        self,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, Mapping[str, NestedTensors]]:
         return {
             "image": self.process_images(mm_data, mm_kwargs),
             "video": self.process_videos(mm_data, mm_kwargs),
-            "audio": self.process_audios(mm_data, mm_kwargs)
+            "audio": self.process_audios(mm_data, mm_kwargs),
         }
 
     def get_modality_num_counter(self, modality: str) -> str:
@@ -358,39 +367,38 @@ class MiniCPMOMultiModalProcessor(
         return super().get_prompt_texts_by_modality(inputs, modality, index)
 
     def _get_prompt_updates(
-            self, mm_items: MultiModalDataItems,
-            hf_processor_mm_kwargs: Mapping[str, Any],
-            out_mm_kwargs: MultiModalKwargs) -> Sequence[PromptReplacement]:
-        placeholder = {
-            "image": self.info.image_pattern,
-            "video": self.info.video_pattern,
-            "audio": self.info.audio_pattern
-        }
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> Sequence[PromptUpdate]:
+        base_updates = super()._get_prompt_updates(
+            mm_items=mm_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+            out_mm_kwargs=out_mm_kwargs,
+        )
 
-        def get_replacement_minicpmv(item_idx: int, modality: str):
-            if modality == "image":
-                return self.get_image_prompt_texts(
-                    mm_items["image"].get_image_size(item_idx), item_idx)
-            elif modality == "video":
-                return self.get_video_prompt_texts(
-                    mm_items["video"].get_frame_size(item_idx),
-                    mm_items["video"].get_num_frames(item_idx))
-            else:  # audio
-                if isinstance(mm_items["audio"], MiniCPMOAudioEmbeddingItems):
-                    single_audio_embeds = mm_items["audio"].get(item_idx)
-                    audio_len = self.info.get_audio_len_by_num_chunks(
-                        sum(chunk_embeds.shape[0]
-                            for chunk_embeds in single_audio_embeds))
-                    return self.get_audio_prompt_texts(audio_len)
-                return self.get_audio_prompt_texts(
-                    len(mm_items["audio"].get(item_idx)))
+        audio_placeholder = self.info.audio_pattern
+
+        def get_audio_replacement(item_idx: int):
+            audios = mm_items.get_items(
+                "audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems))
+
+            if isinstance(audios, MiniCPMOAudioEmbeddingItems):
+                single_audio_embeds = audios.get(item_idx)["audio_embeds"]
+                audio_len = self.info.get_audio_len_by_num_chunks(
+                    sum(chunk_embeds.shape[0]
+                        for chunk_embeds in single_audio_embeds))
+            else:
+                audio_len = audios.get_audio_length(item_idx)
+
+            return self.get_audio_prompt_texts(audio_len)
 
         return [
-            PromptReplacement(modality=modality,
-                              target=placeholder[modality],
-                              replacement=partial(get_replacement_minicpmv,
-                                                  modality=modality))
-            for modality in ("image", "video", "audio")
+            *base_updates,
+            PromptReplacement(modality="audio",
                              target=audio_placeholder,
+                              replacement=get_audio_replacement),
         ]
 
     def _get_mm_fields_config(
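Note: the MiniCPM-O processor now reuses the superclass's image/video updates and adds a single audio `PromptReplacement` whose replacement text is computed per item. The shape of that API can be mimicked without vLLM; the class below is a simplified stand-in for the real `PromptReplacement`, not its actual implementation:

from dataclasses import dataclass
from typing import Callable, Union

@dataclass
class SimplePromptReplacement:
    # Toy stand-in for vllm.multimodal.processing.PromptReplacement.
    modality: str
    target: str
    replacement: Union[str, Callable[[int], str]]

    def apply(self, prompt: str, item_idx: int) -> str:
        text = (self.replacement(item_idx)
                if callable(self.replacement) else self.replacement)
        return prompt.replace(self.target, text, 1)

audio_lengths = [3, 7]  # pretend per-item audio token counts

def get_audio_replacement(item_idx: int) -> str:
    return "<audio>" * audio_lengths[item_idx]

update = SimplePromptReplacement("audio", "(<audio>./</audio>)",
                                 get_audio_replacement)
prompt = "Listen: (<audio>./</audio>) then (<audio>./</audio>)"
prompt = update.apply(prompt, 0)
prompt = update.apply(prompt, 1)
print(prompt)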
@@ -24,7 +24,6 @@
 """Inference-only MiniCPM-V model compatible with HuggingFace weights."""
 import math
 import re
-from collections import Counter
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property, partial
 from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
@@ -51,13 +50,16 @@ from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, PlaceholderRange)
-from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageSize,
+                                    MultiModalInputs, NestedTensors,
+                                    PlaceholderRange)
+from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
+                                   ImageProcessorItems, ImageSize,
                                    ModalityData, ModalityDataItems,
                                    MultiModalDataItems, MultiModalDataParser,
-                                   VideoItem)
+                                   VideoItem, VideoProcessorItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement)
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
@@ -557,8 +559,13 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
             outputs = {key: outputs[key][0] for key in valid_keys}
         return outputs
 
-    def process_images(self, mm_data: Mapping[str, object],
-                       mm_kwargs: Mapping[str, object]) -> Dict[str, object]:
+    def process_images(
+        self,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, NestedTensors]:
+        mm_data = dict(mm_data)
+
         images = mm_data.pop("images", [])
         image_embeds = mm_data.pop("image_embeds", [])
         if isinstance(images, Image.Image):
@@ -568,8 +575,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
                 prompt=self.info.image_pattern * len(images),
                 mm_data={"images": images},
                 mm_kwargs=mm_kwargs)
-            image_outputs = MiniCPMVMultiModalProcessor.\
-                repack_processor_outputs(image_outputs)
+            image_outputs = self.repack_processor_outputs(image_outputs)
         elif len(image_embeds) > 0:
             image_sizes = mm_data.pop("image_sizes", None)
             image_outputs = {
@@ -580,8 +586,13 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
             image_outputs = {}
         return image_outputs
 
-    def process_videos(self, mm_data: Mapping[str, object],
-                       mm_kwargs: Mapping[str, object]) -> Dict[str, object]:
+    def process_videos(
+        self,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, NestedTensors]:
+        mm_data = dict(mm_data)
+
         videos = mm_data.pop("videos", [])
         video_embeds = mm_data.pop("video_embeds", [])
         if len(videos) > 0 and isinstance(videos[0], Image.Image):
@@ -635,10 +646,14 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
     def get_placeholder_split_pattern(self) -> str:
         return r"\(<(?:image|video)>./</(?:image|video)>\)"
 
-    def process_mm_inputs(self, mm_data, mm_kwargs) -> object:
+    def process_mm_inputs(
+        self,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, Mapping[str, NestedTensors]]:
         return {
             "image": self.process_images(mm_data, mm_kwargs),
-            "video": self.process_videos(mm_data, mm_kwargs)
+            "video": self.process_videos(mm_data, mm_kwargs),
         }
 
     def get_input_modalities(self, mm_data) -> List[str]:
@@ -655,8 +670,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
         elif modality == "video":
             return "video_image_sizes"
 
-    def get_num_slices_by_modality(self, inputs: Dict[str, object],
-                                   modality: str, index: int) -> int:
+        raise NotImplementedError(modality)
+
+    def get_num_slices_by_modality(self, inputs: dict[str, Any], modality: str,
+                                   index: int) -> int:
         if modality == "image":
             return self.info.get_image_slice_nums(
                 inputs[modality]["image_sizes"][index],
@@ -669,20 +686,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
         else:
            raise ValueError(f"Unexpected modality: {modality}")
 
-    def check_mm_inputs(self, inputs: Dict[str, object],
-                        matches: List[str]) -> None:
-        counts = Counter(matches)
-        for modality, count in counts.items():
-            if modality not in inputs or not inputs[modality]:
-                raise ValueError(f"None input data of {modality}."
-                                 " But prompt requires.")
-            counter_key = self.get_modality_num_counter(modality)
-            if len(inputs[modality][counter_key]) != count:
-                raise ValueError(f"The prompt requires {count} "
-                                 f"{modality} inputs while you pass "
-                                 f"{len(inputs[modality][counter_key])}")
-
-    def get_prompt_texts_by_modality(self, inputs: Dict[str, object],
+    def get_prompt_texts_by_modality(self, inputs: dict[str, Any],
                                      modality: str, index: int) -> str:
         if modality == "image":
             return self.get_image_prompt_texts(
@@ -715,13 +719,23 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
         tokenizer = self.info.get_tokenizer()
         inputs = self.process_mm_inputs(mm_data, mm_kwargs)
         mm_input_modalities = self.get_input_modalities(inputs)
-        num_mm_slices = {modality: [] for modality in mm_input_modalities}
+
+        num_mm_slices_lst = {
+            modality: list[int]()
+            for modality in mm_input_modalities
+        }
         for modality in mm_input_modalities:
             num_counter_key = self.get_modality_num_counter(modality)
             for index in range(len(inputs[modality][num_counter_key])):
-                num_mm_slices[modality].append(
+                num_mm_slices_lst[modality].append(
                     self.get_num_slices_by_modality(inputs, modality, index))
-        return {
+
+        num_mm_slices = {
+            modality: torch.tensor(v)
+            for modality, v in num_mm_slices_lst.items()
+        }
+
+        return BatchFeature({
             "input_ids": np.array([tokenizer.encode(prompt)]),
             **{
                 key: value
@@ -732,7 +746,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
                 f"{modality}_num_slices": num_mm_slices[modality]
                 for modality in mm_input_modalities
             }
-        }
+        })
 
     def _hf_processor_applies_updates(
         self,
@@ -743,28 +757,42 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
         return False
 
     def _get_prompt_updates(
-            self, mm_items: MultiModalDataItems,
-            hf_processor_mm_kwargs: Mapping[str, Any],
-            out_mm_kwargs: MultiModalKwargs) -> Sequence[PromptReplacement]:
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> Sequence[PromptUpdate]:
         placeholder = {
             "image": self.info.image_pattern,
             "video": self.info.video_pattern,
         }
 
-        def get_replacement_minicpmv(item_idx: int, modality: str):
-            if modality == "image":
-                return self.get_image_prompt_texts(
-                    mm_items["image"].get_image_size(item_idx), item_idx)
-            else:  # video
-                return self.get_video_prompt_texts(
-                    mm_items["video"].get_frame_size(item_idx),
-                    mm_items["video"].get_num_frames(item_idx))
+        def get_image_replacement(item_idx: int):
+            images = mm_items.get_items(
+                "image", (MiniCPMVImageEmbeddingItems, ImageProcessorItems))
+
+            image_size = images.get_image_size(item_idx)
+
+            return self.get_image_prompt_texts(image_size, item_idx)
+
+        def get_video_replacement(item_idx: int):
+            videos = mm_items.get_items(
+                "video", (MiniCPMVVideoEmbeddingItems, VideoProcessorItems))
+
+            frame_size = videos.get_frame_size(item_idx)
+            num_frames = videos.get_num_frames(item_idx)
+
+            return self.get_video_prompt_texts(frame_size, num_frames)
+
+        get_replacement = {
+            "image": get_image_replacement,
+            "video": get_video_replacement,
+        }
 
         return [
             PromptReplacement(modality=modality,
                               target=placeholder[modality],
-                              replacement=partial(get_replacement_minicpmv,
-                                                  modality=modality))
+                              replacement=get_replacement[modality])
             for modality in ("image", "video")
         ]
 
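Note: the MiniCPM-V `_call_hf_processor` change accumulates slice counts in plain lists, converts them to tensors per modality, and returns a `transformers.BatchFeature` instead of a bare dict. A trimmed sketch of that packaging step; the keys and counts below are invented for illustration:

import numpy as np
import torch
from transformers import BatchFeature

input_ids = [1, 2, 3, 4]                              # pretend tokenized prompt
num_mm_slices_lst = {"image": [2, 5], "video": [9]}   # slices per item

num_mm_slices = {
    modality: torch.tensor(v)
    for modality, v in num_mm_slices_lst.items()
}

batch = BatchFeature({
    "input_ids": np.array([input_ids]),
    **{f"{m}_num_slices": num_mm_slices[m] for m in num_mm_slices},
})
print(batch["image_num_slices"])  # tensor([2, 5])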
@@ -1478,7 +1478,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
                              f"Got type: {type(embed_is_patch)}")
 
         num_crops = kwargs.pop("num_crops", None)
-        if not isinstance(num_crops, torch.Tensor):
+        if not isinstance(num_crops, (torch.Tensor, list)):
             raise ValueError("Incorrect type of num_crops. "
                              f"Got type: {type(num_crops)}")
 
@@ -1,9 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import math
+from collections.abc import Iterable, Mapping
 from dataclasses import dataclass, fields
 from functools import cached_property
-from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
+from typing import List, Optional, Set, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -683,79 +684,6 @@ class VisionLanguageAdapter(nn.Module):
 # and [`MistralForCausalLM`] for its language decoder.
 
 
-def get_pixtral_hf_patch_grid_length(*, image_size: int,
-                                     patch_size: int) -> int:
-    # Since interpolation is applied, the image size need not be divisible
-    # assert image_size % patch_size == 0
-    return image_size // patch_size
-
-
-def get_pixtral_hf_image_feature_size(
-    *,
-    image_size: int,
-    patch_size: int,
-) -> int:
-    grid_length = get_pixtral_hf_patch_grid_length(
-        image_size=image_size,
-        patch_size=patch_size,
-    )
-
-    # Consider the image_break_token
-    return (grid_length + 1) * grid_length
-
-
-def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int:
-    grid_length = get_pixtral_hf_patch_grid_length(
-        image_size=hf_config.image_size,
-        patch_size=hf_config.patch_size,
-    )
-
-    # Consider the image_break_token
-    return (grid_length + 1) * grid_length
-
-
-def dummy_image_for_pixtral_hf(
-    hf_config: PixtralVisionConfig,
-    num_images: int,
-    *,
-    image_width_override: Optional[int] = None,
-    image_height_override: Optional[int] = None,
-):
-    width = height = hf_config.image_size
-    if image_width_override is not None:
-        width = image_width_override
-    if image_height_override is not None:
-        height = image_height_override
-
-    image = Image.new("RGB", (width, height), color=0)
-    return {"image": image if num_images == 1 else [image] * num_images}
-
-
-# Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501
-# https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180
-def get_pixtral_hf_image_feature_grid_size(
-    hf_config: PixtralVisionConfig,
-    *,
-    image_width: int,
-    image_height: int,
-) -> tuple[int, int]:
-    max_width = max_height = hf_config.image_size
-    patch_width = patch_height = hf_config.patch_size
-
-    ratio = max(image_width / max_width, image_height / max_height)
-
-    if ratio > 1:
-        image_width = int(math.ceil(image_width / ratio))
-        image_height = int(math.ceil(image_height / ratio))
-
-    nrows, ncols = _get_pixtral_hf_num_image_tokens(
-        (image_height, image_width),
-        (patch_height, patch_width),
-    )  # type: ignore
-
-    return ncols, nrows
-
-
 class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
 
     def get_num_image_tokens(
@@ -764,13 +692,21 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
         image_width: int,
         image_height: int,
     ) -> int:
-        return get_pixtral_hf_image_feature_size(
-            image_size=self.vision_config.image_size,
-            patch_size=self.vision_config.patch_size,
+        ncols, nrows = self.get_patch_grid_size(
+            image_width=image_width,
+            image_height=image_height,
         )
+
+        # Consider the image_break_token
+        return (ncols + 1) * nrows
 
     def get_max_image_tokens(self) -> int:
-        return get_max_pixtral_hf_image_tokens(self.vision_config)
+        image_size = self.get_image_size()
+
+        return self.get_num_image_tokens(
+            image_width=image_size,
+            image_height=image_size,
+        )
 
     def get_image_size(self) -> int:
         return self.vision_config.image_size
@@ -779,10 +715,34 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
         return self.vision_config.patch_size
 
     def get_patch_grid_length(self) -> int:
-        return get_pixtral_hf_patch_grid_length(
-            image_size=self.vision_config.image_size,
-            patch_size=self.vision_config.patch_size,
-        )
+        image_size, patch_size = self.get_image_size(), self.get_patch_size()
+
+        # Since interpolation is applied, the image size need not be divisible
+        # assert image_size % patch_size == 0
+        return image_size // patch_size
+
+    # Adapted from: https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/pixtral/image_processing_pixtral.py#L99
+    def get_patch_grid_size(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+    ) -> tuple[int, int]:
+        max_width = max_height = self.get_image_size()
+        patch_width = patch_height = self.get_patch_size()
+
+        ratio = max(image_width / max_width, image_height / max_height)
+
+        if ratio > 1:
+            image_width = int(math.ceil(image_width / ratio))
+            image_height = int(math.ceil(image_height / ratio))
+
+        nrows, ncols = _get_pixtral_hf_num_image_tokens(
+            (image_height, image_width),
+            (patch_height, patch_width),
+        )  # type: ignore
+
+        return ncols, nrows
 
 
 class PixtralHFMLP(nn.Module):
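Note: `get_patch_grid_size` first rescales the image so neither side exceeds `image_size`, then counts `patch_size` tiles, and `get_num_image_tokens` / `get_max_image_tokens` add one image-break token per row. Worked through under an assumed configuration (image_size=1024, patch_size=16, which are not stated in this hunk):

import math

image_size, patch_size = 1024, 16              # assumed config values

# A 2048x1536 input: ratio 2 resizes it to 1024x768.
ratio = max(2048 / image_size, 1536 / image_size)
ncols = math.ceil(math.ceil(2048 / ratio) / patch_size)   # 64
nrows = math.ceil(math.ceil(1536 / ratio) / patch_size)   # 48

# get_num_image_tokens: one image_break token per row.
assert (ncols + 1) * nrows == 3120

# get_max_image_tokens: same formula evaluated at the full image_size.
grid = image_size // patch_size                            # 64
assert (grid + 1) * grid == 4160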
@@ -222,10 +222,10 @@ class Qwen2AudioMultiModalProcessor(
             num_features = audio_output_lengths[item_idx]
             if num_features == 0:
                 audios = mm_items.get_items("audio", AudioProcessorItems)
-                audio = audios.get(item_idx)
-                raise ValueError(
-                    f"The audio {audio} (len={len(audio)}) is too short "
-                    "to be represented inside the model")
+                audio_len = audios.get_audio_length(item_idx)
+
+                raise ValueError(f"The audio (len={audio_len}) is too short "
+                                 "to be represented inside the model")
 
             audio_tokens = [audio_token_id] * num_features
 
@@ -433,6 +433,10 @@ class MultiModalFieldConfig:
             :func:`MultiModalFieldConfig.flat`
         """
 
+        if size_per_item.ndim != 1:
+            raise ValueError("size_per_item should be a 1-D tensor, "
+                             f"but found shape: {size_per_item.shape}")
+
         slice_idxs = [0, *accumulate(size_per_item)]
         slices = [
             slice(slice_idxs[i], slice_idxs[i + 1])
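Note: the validated `size_per_item` tensor feeds directly into the cumulative-sum slicing right below it. The same idea in isolation, with arbitrary sample sizes:

from itertools import accumulate

import torch

size_per_item = torch.tensor([3, 1, 4])   # items occupy 3, 1 and 4 rows

# Cumulative boundaries [0, 3, 4, 8] become one slice per item.
slice_idxs = [0, *accumulate(size_per_item.tolist())]
slices = [
    slice(slice_idxs[i], slice_idxs[i + 1])
    for i in range(len(size_per_item))
]

flat = torch.arange(8)
print([flat[s].tolist() for s in slices])  # [[0, 1, 2], [3], [4, 5, 6, 7]]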
@@ -176,6 +176,10 @@ class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
     def __init__(self, data: Sequence[HfAudioItem]) -> None:
         super().__init__(data, "audio")
 
+    def get_audio_length(self, item_idx: int) -> int:
+        audio = self.get(item_idx)
+        return len(audio)
+
 
 class AudioEmbeddingItems(EmbeddingItems):
 
@@ -1311,8 +1311,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
 
     def _bind_and_group_updates(
         self,
-        prompt_updates: list[PromptUpdate],
-    ) -> dict[str, list[BoundPromptUpdate]]:
+        prompt_updates: Sequence[PromptUpdate],
+    ) -> dict[str, Sequence[BoundPromptUpdate]]:
         tokenizer = self.info.get_tokenizer()
 
         it = (update.bind(tokenizer) for update in prompt_updates)
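Note: `_bind_and_group_updates` now accepts any `Sequence[PromptUpdate]` and, per its return type, yields bound updates grouped by modality. The grouping itself is not shown in this hunk, so the sketch below is only a guess at its shape, with toy classes replacing the real `PromptUpdate` / `BoundPromptUpdate` types:

from collections import defaultdict
from dataclasses import dataclass
from typing import Sequence

@dataclass
class ToyUpdate:
    modality: str
    target: str

    def bind(self, tokenizer: str) -> "ToyBoundUpdate":
        return ToyBoundUpdate(self.modality, self.target, tokenizer)

@dataclass
class ToyBoundUpdate:
    modality: str
    target: str
    tokenizer: str

def bind_and_group(updates: Sequence[ToyUpdate], tokenizer: str) -> dict:
    grouped = defaultdict(list)
    for bound in (u.bind(tokenizer) for u in updates):
        grouped[bound.modality].append(bound)
    return dict(grouped)

updates = [ToyUpdate("image", "<image>"), ToyUpdate("audio", "<audio>")]
print(sorted(bind_and_group(updates, tokenizer="tok")))  # ['audio', 'image']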