[VLM] Merged multi-modal processor for Pixtral (#12211)

Signed-off-by: remi <remi@mistral.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
commit 61c6a5a796
parent 74bc397b0a
@@ -43,12 +43,18 @@ from vllm.sampling_params import SamplingParams
 # python demo.py advanced
 
 
-def run_simple_demo():
+def run_simple_demo(args: argparse.Namespace):
     model_name = "mistralai/Pixtral-12B-2409"
     sampling_params = SamplingParams(max_tokens=8192)
 
-    # Lower max_num_seqs or max_model_len on low-VRAM GPUs.
-    llm = LLM(model=model_name, tokenizer_mode="mistral")
+    # Lower max_model_len and/or max_num_seqs on low-VRAM GPUs.
+    llm = LLM(
+        model=model_name,
+        tokenizer_mode="mistral",
+        max_model_len=4096,
+        max_num_seqs=2,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+    )
 
     prompt = "Describe this image in one sentence."
     image_url = "https://picsum.photos/id/237/200/300"
@@ -76,7 +82,7 @@ def run_simple_demo():
     print(outputs[0].outputs[0].text)
 
 
-def run_advanced_demo():
+def run_advanced_demo(args: argparse.Namespace):
     model_name = "mistralai/Pixtral-12B-2409"
     max_img_per_msg = 5
     max_tokens_per_img = 4096
@@ -87,6 +93,7 @@ def run_advanced_demo():
         tokenizer_mode="mistral",
         limit_mm_per_prompt={"image": max_img_per_msg},
         max_model_len=max_img_per_msg * max_tokens_per_img,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
 
     prompt = "Describe the following image."
@@ -153,14 +160,19 @@ def main():
         help="Specify the demo mode: 'simple' or 'advanced'",
     )
 
+    parser.add_argument(
+        '--disable-mm-preprocessor-cache',
+        action='store_true',
+        help='If True, disables caching of multi-modal preprocessor/mapper.')
+
     args = parser.parse_args()
 
     if args.mode == "simple":
         print("Running simple demo...")
-        run_simple_demo()
+        run_simple_demo(args)
    elif args.mode == "advanced":
         print("Running advanced demo...")
-        run_advanced_demo()
+        run_advanced_demo(args)
 
 
 if __name__ == "__main__":
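
Illustrative usage (not part of the patch; the `demo.py` name is assumed from the script's own usage comment): both demo modes now thread `args` through, so the new cache toggle can be exercised from the command line:

    python demo.py simple --disable-mm-preprocessor-cache
    python demo.py advanced --disable-mm-preprocessor-cache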

@@ -2,17 +2,23 @@
 
 import copy
 from functools import partial
-from typing import Optional
+from typing import Optional, Union
 
 import numpy as np
 import pytest
+from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
+                                                       UserMessage)
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from PIL import Image
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 from vllm.config import ModelConfig
 from vllm.inputs import InputProcessingContext
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.processing import ProcessingCache
-from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
+from vllm.multimodal.inputs import MultiModalInputs
+from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
+from vllm.transformers_utils.tokenizer import (MistralTokenizer,
+                                               cached_tokenizer_from_config)
 
 from ....multimodal.utils import random_audio, random_image, random_video
 from ...registry import HF_EXAMPLE_MODELS
@@ -85,14 +91,6 @@ def _test_processing_correctness(
         partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
     }
 
-    tokenizer_encode_kwargs = {}
-    if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
-        # For some multimodal models, tokenizer will always add bos_token
-        # at the beginning of prompt by default, causing hf_processor outputs
-        # incorrect token ids. So we need use `add_special_tokens=False` here
-        # to leave bos_token to be added by the processor.
-        tokenizer_encode_kwargs = {"add_special_tokens": False}
-
     for batch_idx in range(num_batches):
         mm_data = {
             k:
@@ -115,43 +113,131 @@ def _test_processing_correctness(
             elif len(mm_data[k]) == 1:
                 mm_data[k] = mm_data[k][0]
 
-        baseline_result = baseline_processor.apply(
-            prompt,
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-        cached_result = cached_processor.apply(
-            prompt,
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-
-        assert _drop_mm_kwargs_keys(
-            baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
-                cached_result, ignore_mm_keys), (
-                    f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
-
-        baseline_tokenized_result = baseline_processor.apply(
-            tokenizer.encode(prompt, **tokenizer_encode_kwargs),
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-
-        assert _drop_mm_kwargs_keys(
-            baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
-                baseline_tokenized_result, ignore_mm_keys), (
-                    f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
-
-        cached_tokenized_result = cached_processor.apply(
-            tokenizer.encode(prompt, **tokenizer_encode_kwargs),
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-
-        assert _drop_mm_kwargs_keys(
-            cached_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
-                cached_tokenized_result, ignore_mm_keys), (
-                    f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
+        if isinstance(tokenizer, MistralTokenizer):
+            _test_processing_correctness_mistral(
+                model_config,
+                tokenizer,
+                prompt,
+                mm_data,
+                baseline_processor,
+                cached_processor,
+                batch_idx,
+                ignore_mm_keys=ignore_mm_keys,
+            )
+        else:
+            _test_processing_correctness_hf(
+                model_config,
+                tokenizer,
+                prompt,
+                mm_data,
+                baseline_processor,
+                cached_processor,
+                batch_idx,
+                ignore_mm_keys=ignore_mm_keys,
+            )
+
+
+def _test_processing_correctness_hf(
+    model_config: ModelConfig,
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    prompt: str,
+    mm_data: MultiModalDataDict,
+    baseline_processor: BaseMultiModalProcessor,
+    cached_processor: BaseMultiModalProcessor,
+    batch_idx: int,
+    ignore_mm_keys: Optional[list[str]] = None,
+):
+    if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
+        # For some multimodal models, tokenizer will always add bos_token
+        # at the beginning of prompt by default, causing hf_processor outputs
+        # incorrect token ids. So we need use `add_special_tokens=False` here
+        # to leave bos_token to be added by the processor.
+        token_prompt = tokenizer.encode(prompt, add_special_tokens=False)
+    else:
+        token_prompt = tokenizer.encode(prompt)
+
+    baseline_result = baseline_processor.apply(
+        prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+    cached_result = cached_processor.apply(
+        prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        baseline_result,
+        cached_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+
+    baseline_tokenized_result = baseline_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        baseline_result,
+        baseline_tokenized_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+
+    cached_tokenized_result = cached_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        cached_result,
+        cached_tokenized_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+
+
+def _test_processing_correctness_mistral(
+    model_config: ModelConfig,
+    tokenizer: MistralTokenizer,
+    prompt: str,
+    mm_data: MultiModalDataDict,
+    baseline_processor: BaseMultiModalProcessor,
+    cached_processor: BaseMultiModalProcessor,
+    batch_idx: int,
+    ignore_mm_keys: Optional[list[str]] = None,
+):
+    images = mm_data.get("image", [])
+    if not isinstance(images, list):
+        images = [images]
+
+    request = ChatCompletionRequest(messages=[
+        UserMessage(content=[
+            TextChunk(text=prompt),
+            *(ImageChunk(image=image) for image in images),
+        ]),
+    ])
+    res = tokenizer.mistral.encode_chat_completion(request)
+    token_prompt = res.tokens
+
+    # Mistral chat outputs tokens directly, rather than text prompts
+    baseline_tokenized_result = baseline_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+    cached_tokenized_result = cached_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        baseline_tokenized_result,
+        cached_tokenized_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
 
 
 # yapf: disable
@@ -173,6 +259,7 @@ def _test_processing_correctness(
     "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
     "meta-llama/Llama-3.2-11B-Vision-Instruct",
     "TIGER-Lab/Mantis-8B-siglip-llama3",
+    "mistralai/Pixtral-12B-2409",
     "mistral-community/pixtral-12b",
     "openbmb/MiniCPM-o-2_6",
     "openbmb/MiniCPM-V-2_6",
@@ -241,8 +328,19 @@ def test_processing_correctness_phi3v(
     )
 
 
-def _drop_mm_kwargs_keys(result: dict,
-                         ignore_mm_keys: Optional[list[str]] = None) -> dict:
+def _inputs_equal(
+    a: MultiModalInputs,
+    b: MultiModalInputs,
+    ignore_mm_keys: Optional[list[str]] = None,
+):
+    return _drop_mm_kwargs_keys(a, ignore_mm_keys) == _drop_mm_kwargs_keys(
+        b, ignore_mm_keys)
+
+
+def _drop_mm_kwargs_keys(
+    result: MultiModalInputs,
+    ignore_mm_keys: Optional[list[str]] = None,
+) -> MultiModalInputs:
     """Drop specified keys from result['mm_kwargs'].
 
     This is mainly to avoid doing exact match of audio_features in ultravox.

@@ -68,23 +68,15 @@ class PixtralHFImagePixelInputs(TypedDict):
     in which case the data is passed as a list instead of a batched tensor.
     """
 
-    feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    A boolean mask indicating which image features correspond
-    to patch tokens.
-
-    Shape: `(batch_size, num_crops, num_patch)`
-    """
-
     embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image embeddings correspond
     to patch tokens.
 
-    Shape: `(batch_size, num_embeds)`
+    Shape: `(batch_size, num_images, num_embeds)`
     """
 
-    num_crops: Union[torch.Tensor, list[torch.Tensor]]
+    num_patches: Union[torch.Tensor, list[torch.Tensor]]
     """Shape: `(batch_size, num_images)`"""
 
@@ -360,16 +352,16 @@ class PixtralHFMultiModalProcessor(
                 image_height=pixel_value.shape[-2],
             ) for pixel_value in processed_outputs["pixel_values"]
         ]
-        num_crops = torch.tensor([(ncols + 1) * nrows
-                                  for ncols, nrows in tile_sizes])
+        num_patches = torch.tensor([(ncols + 1) * nrows
+                                    for ncols, nrows in tile_sizes])
         # Each image may result to masks of different sizes, so we need to
-        # flatten the list and later use `num_crops` to get per-image masks.
-        embed_is_patch = torch.tensor(
-            flatten_2d_lists([([True] * ncols + [False]) * nrows
-                              for ncols, nrows in tile_sizes]))
-        processed_outputs["num_crops"] = num_crops
+        # later use `num_patches` to get per-image masks.
+        embed_is_patch = [
+            torch.tensor(([True] * ncols + [False]) * nrows)
+            for ncols, nrows in tile_sizes
+        ]
+        processed_outputs["num_patches"] = num_patches
         processed_outputs["embed_is_patch"] = embed_is_patch
-        processed_outputs["feat_is_patch"] = embed_is_patch
 
         return processed_outputs
 
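An illustrative sketch (tile sizes invented, not part of the patch) of the per-image mask layout built above; each image now keeps its own `embed_is_patch` tensor rather than one flattened tensor, and `num_patches` records each mask's length:

    import torch

    tile_sizes = [(2, 1), (3, 2)]  # hypothetical (ncols, nrows) per image
    num_patches = torch.tensor([(ncols + 1) * nrows
                                for ncols, nrows in tile_sizes])
    embed_is_patch = [
        torch.tensor(([True] * ncols + [False]) * nrows)
        for ncols, nrows in tile_sizes
    ]
    # num_patches -> tensor([3, 8]); each mask is True at patch slots and
    # False at the break-token slot that closes each row of tiles.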
@@ -378,14 +370,10 @@ class PixtralHFMultiModalProcessor(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1)
         return dict(
-            feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
-                "image", num_crops),
-            embed_is_patch=MultiModalFieldConfig.flat_from_sizes(
-                "image", num_crops),
-            num_crops=MultiModalFieldConfig.batched("image"),
             pixel_values=MultiModalFieldConfig.batched("image"),
+            num_patches=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
         )
 
@@ -628,27 +616,21 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                     f"Got type: {type(pixel_values)}")
 
             if self.config.vision_config.model_type == "pixtral":
-                feat_is_patch = kwargs.pop("feat_is_patch")
-                if not isinstance(feat_is_patch, (torch.Tensor, list)):
-                    raise ValueError("Incorrect type of feat_is_patch. "
-                                     f"Got type: {type(feat_is_patch)}")
-
                 embed_is_patch = kwargs.pop("embed_is_patch")
                 if not isinstance(embed_is_patch, (torch.Tensor, list)):
                     raise ValueError("Incorrect type of embed_is_patch. "
                                      f"Got type: {type(embed_is_patch)}")
 
-                num_crops = kwargs.pop("num_crops")
-                if not isinstance(num_crops, (torch.Tensor, list)):
-                    raise ValueError("Incorrect type of num_crops. "
-                                     f"Got type: {type(num_crops)}")
+                num_patches = kwargs.pop("num_patches")
+                if not isinstance(num_patches, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of num_patches. "
+                                     f"Got type: {type(num_patches)}")
 
                 return PixtralHFImagePixelInputs(
                     type="pixel_values_pixtral",
                     pixel_values=flatten_bn(pixel_values),
-                    feat_is_patch=feat_is_patch,
                     embed_is_patch=embed_is_patch,
-                    num_crops=num_crops,
+                    num_patches=num_patches,
                 )
 
             return LlavaImagePixelInputs(
@@ -687,21 +669,26 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                             PixtralHFVisionModel],
         pixel_values: Union[torch.Tensor, list[torch.Tensor]],
-    ) -> torch.Tensor:
-
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
         image_features = vision_tower(pixel_values)
 
-        return self._select_image_features(
-            image_features,
-            strategy=self.config.vision_feature_select_strategy,
+        def select_features(leaf: torch.Tensor):
+            return self._select_image_features(
+                leaf,
+                strategy=self.config.vision_feature_select_strategy,
+            )
+
+        return cast(
+            Union[torch.Tensor, tuple[torch.Tensor, ...]],
+            json_map_leaves(select_features, image_features),
         )
 
     def _process_image_pixels(
         self,
         inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         assert self.vision_tower is not None
 
         pixel_values = inputs["pixel_values"]
@@ -731,45 +718,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 
     def _get_mm_embeds(
         self,
-        features: torch.Tensor,  # Shape: (num_crop, num_patch, d)
-        feat_is_patch: torch.Tensor,  # Shape: (num_crop, num_patch)
-        num_crops: torch.Tensor,  # Shape: (num_images,)
-        embed_is_patch: torch.Tensor,  # Shape: (num_embeds,)
-    ) -> list[torch.Tensor]:
+        features: torch.Tensor,  # Shape: (num_patch, d)
+        num_patches: torch.Tensor,  # Shape: (num_images,)
+        embed_is_patch: torch.Tensor,  # Shape: (num_images, num_embeds)
+    ) -> tuple[torch.Tensor, ...]:
         """Scatter the patch features into a contiguous tensor that corresponds
         to the embedding tokens defined by the multimodal processor.
 
         Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
         """
-        # Insert columns of nan values according to `feat_is_patch`. This work
+        # Insert columns of nan values according to `embed_is_patch`. This work
         # ideally should be done in `_process_image_input`, but
         # `_process_image_input` is used in both V0 and V1 path. It's safer to
         # put the logic here.
         # FIXME: Move this logic to `_process_image_input` when v0 is
         # deprecated. Merge this function with `Molmo._get_mm_embeds`.
-        feat_is_patch = feat_is_patch.view(-1)
-        embed_is_patch = embed_is_patch.view(-1)
-        expanded_embedding = torch.full(
-            (sum(num_crops), *features.shape[1:]),
-            torch.nan,
-            dtype=features.dtype).to(features.device)
-        expanded_embedding[feat_is_patch] = features
+        num_patches_per_image: list[int] = num_patches.tolist()
 
-        num_crops_per_image = num_crops.tolist()
-        feats_per_image = expanded_embedding.split(num_crops_per_image)
-        f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
+        embeds_flat = features.new_full(
+            (sum(num_patches_per_image), *features.shape[1:]),
+            fill_value=torch.nan,
+        )
+        embeds_flat[embed_is_patch.view(-1)] = features
 
-        embed_dim = expanded_embedding.shape[-1]
-        num_embeds = embed_is_patch.shape[0]
-
-        embeds_in_batch = list[torch.Tensor]()
-        for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
-            embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
-            embeds[embed_is_patch] = feats[f_is_patch]
-            embeds_in_batch.append(embeds)
-
-        return embeds_in_batch
+        return embeds_flat.split(num_patches_per_image)
 
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
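A standalone toy run (values invented, not part of the patch) of the nan-scatter pattern used by `_get_mm_embeds` here and by the matching Pixtral method later in this diff: patch features are scattered into nan-filled per-image tensors, and the nan rows mark non-patch embedding slots to be filled in (or filtered out) downstream:

    import torch

    features = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])  # (num_patch, d)
    num_patches = torch.tensor([3, 2])  # embedding slots per image
    embed_is_patch = torch.tensor([True, True, False, True, False])

    num_per_image = num_patches.tolist()
    embeds_flat = features.new_full((sum(num_per_image), 2),
                                    fill_value=torch.nan)
    embeds_flat[embed_is_patch.view(-1)] = features
    per_image = embeds_flat.split(num_per_image)
    # per_image[0] == [[1, 1], [2, 2], [nan, nan]]
    # per_image[1] == [[3, 3], [nan, nan]]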
@@ -784,12 +756,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             # The path is used for pixtral (V0 only) and llava (V0/V1)
             return vision_embeddings
 
-        nested_emb = [
+        return flatten_2d_lists(
             self._get_mm_embeds(*args) for args in zip(
-                vision_embeddings, image_input["feat_is_patch"],
-                image_input["num_crops"], image_input["embed_is_patch"])
-        ]
-        return flatten_2d_lists(nested_emb)
+                vision_embeddings,
+                image_input["num_patches"],
+                image_input["embed_is_patch"],
+            ))
 
     def get_input_embeddings(
         self,
@@ -805,9 +777,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         )
 
         inputs_embeds = merge_multimodal_embeddings(
-            input_ids, inputs_embeds, cast(NestedTensors,
-                                           patch_embeddings),
-            self.config.image_token_index)
+            input_ids,
+            inputs_embeds,
+            cast(NestedTensors, patch_embeddings),
+            self.config.image_token_index,
+        )
         return inputs_embeds
 
     def forward(

@@ -1585,15 +1585,13 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
 
         image_features = self._process_image_input(image_input)
 
-        nested_embeds = [
+        return flatten_2d_lists(
             self._get_mm_embeds(*args) for args in zip(
                 image_features,
                 image_input["feat_is_patch"],
                 image_input["num_crops"],
                 image_input["embed_is_patch"],
-            )
-        ]
-        return flatten_2d_lists(nested_embeds)
+            ))
 
     def get_input_embeddings(
         self,

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
-                    TypedDict, Union)
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Literal, Optional, Set, Tuple, TypedDict, Union
 
 import torch
 from torch import nn
@ -17,7 +16,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
|||||||
from vllm.multimodal.parse import MultiModalDataItems
|
from vllm.multimodal.parse import MultiModalDataItems
|
||||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||||
BaseProcessingInfo, PromptIndexTargets,
|
BaseProcessingInfo, PromptIndexTargets,
|
||||||
PromptInsertion, PromptReplacement,
|
PromptInsertion, PromptUpdate,
|
||||||
PromptUpdateDetails)
|
PromptUpdateDetails)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
@ -144,7 +143,7 @@ class PaliGemmaMultiModalProcessor(
@@ -144,7 +143,7 @@ class PaliGemmaMultiModalProcessor(
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
-    ) -> list[PromptReplacement]:
+    ) -> Sequence[PromptUpdate]:
         hf_config = self.info.get_hf_config()
         image_token_id = hf_config.image_token_index
 
@@ -1,26 +1,28 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import math
-from collections.abc import Iterable, Mapping
+from collections.abc import Iterable, Mapping, Sequence
 from dataclasses import dataclass, fields
 from functools import cached_property
-from typing import List, Optional, Set, Tuple, Union
+from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union, cast
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from mistral_common.protocol.instruct.messages import ImageChunk
+from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
 from PIL import Image
-from transformers import PixtralVisionConfig
+from transformers import PixtralVisionConfig, TensorType
+from transformers.image_utils import ImageInput
 from transformers.models.pixtral.image_processing_pixtral import (
     _num_image_tokens as _get_pixtral_hf_num_image_tokens)
 from transformers.models.pixtral.modeling_pixtral import (
     PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
+from transformers.tokenization_utils_base import TextInput
 
 from vllm.config import VllmConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
+from vllm.jsontree import JSONTree, json_map_leaves
 from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@ -31,13 +33,20 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
|||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||||
from vllm.multimodal.inputs import PlaceholderRange
|
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
|
||||||
from vllm.multimodal.utils import consecutive_placeholder_ranges
|
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
||||||
from vllm.sequence import IntermediateTensors, SequenceData
|
MultiModalDataItems)
|
||||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||||
|
BaseProcessingInfo, PromptReplacement,
|
||||||
|
PromptUpdate)
|
||||||
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
|
||||||
|
cached_tokenizer_from_config)
|
||||||
|
from vllm.utils import flatten_2d_lists
|
||||||
|
|
||||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||||
from .utils import (init_vllm_registered_model, maybe_prefix,
|
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
|
||||||
merge_multimodal_embeddings)
|
merge_multimodal_embeddings)
|
||||||
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
|
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
|
||||||
|
|
||||||
@@ -48,132 +57,275 @@ except ImportError:
     USE_XFORMERS_OPS = False
 
 
-def get_max_pixtral_image_tokens(ctx: InputContext):
-    tokenizer = cached_tokenizer_from_config(ctx.model_config)
-    mm_encoder = tokenizer.instruct.mm_encoder
-
-    image_config = mm_encoder.mm_config if hasattr(
-        mm_encoder, "mm_config") else mm_encoder.image_config
-
-    max_image_size = image_config.max_image_size
-    image_patch_size = image_config.image_patch_size
-
-    return ((max_image_size // image_patch_size)**2)
-
-
-def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
-                           mm_counts: Mapping[str, int]):
-    tokenizer = cached_tokenizer_from_config(ctx.model_config)
-
-    mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
-    image_token_id = mm_encoder.special_ids.img
-
-    mm_config = ctx.get_mm_config()
-    num_images = mm_config.get_limit_per_prompt("image")
-
-    # dummy size
-    size = 256
-    image = Image.new("RGB", (size, size), color=0)
-
-    encoding = tokenizer.instruct.mm_encoder(ImageChunk(image=image))
-    image_feature_size = len(encoding.tokens)
-    num_image_tokens = image_feature_size * num_images
-    seq_data = SequenceData.from_prompt_token_counts(
-        (image_token_id, num_image_tokens),
-        (0, seq_len - num_image_tokens),
-    )
-
-    mm_data = {"image": num_images * [image]}
-    mm_placeholders = {
-        "image":
-        consecutive_placeholder_ranges(num_items=num_images,
-                                       item_size=image_feature_size)
-    }
-    return DummyData(seq_data, mm_data, mm_placeholders)
-
-
-def input_mapper_for_pixtral(ctx: InputContext,
-                             data: object) -> MultiModalKwargs:
-    """Maps the input data to its MultiModalKwargs (if any).
-
-    Args:
-        ctx: Context of the loaded model.
-        data: data potentially containing PIL images to be processed
-            and mapped to `images`.
-
-    Returns:
-        MultiModalKwargs containing the stacked normalized images tensor or
-        image embeddings.
-    """
-    tokenizer = cached_tokenizer_from_config(ctx.model_config)
-
-    data_list = data if isinstance(data, list) else [data]
-
-    images = []
-    image_tokens_list = []
-    for image_data in data_list:
-        image = ImageChunk(image=image_data)
-        encoding = tokenizer.instruct.mm_encoder(image)
-        image = torch.from_numpy(encoding.image).to(dtype=torch.float16)
-        images.append(image)
-        image_tokens_list.append(encoding.tokens)
-
-    image_tokens = torch.tensor([
-        token_id for image_tokens in image_tokens_list
-        for token_id in image_tokens
-    ])
-    return MultiModalKwargs({"images": images, "image_tokens": image_tokens})
-
-
-def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs):
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        return inputs
-
-    prompt_token_ids = inputs.get("prompt_token_ids")
-    prompt = inputs.get("prompt")
-    tokenizer = cached_tokenizer_from_config(ctx.model_config)
-
-    mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
-    image_token_id = mm_encoder.special_ids.img
-    image_break_id = mm_encoder.special_ids.img_break
-    image_end_id = mm_encoder.special_ids.img_end
-
-    if image_token_id not in inputs['prompt_token_ids']:
-        raise ValueError(
-            f"You've passed {inputs=} without {image_token_id=}"
-            " Make sure to process your input via mistral_common's"
-            " tokenizer or pass a chat completion request. For more"
-            " For more info, see: "
-            "https://github.com/vllm-project/vllm/issues/8411.")
-
-    # Get precise tracking of placeholder positions
-    placeholder_ranges = []
-    curr_offset = -1
-    curr_length = 0
-    for i in range(len(prompt_token_ids)):
-        if prompt_token_ids[i] in (image_token_id, image_break_id):
-            if curr_offset < 0:
-                curr_offset = i
-            curr_length += 1
-        elif prompt_token_ids[i] == image_end_id:
-            curr_length += 1
-            placeholder_ranges.append(
-                PlaceholderRange(offset=curr_offset, length=curr_length))
-            curr_offset = -1
-            curr_length = 0
-        else:
-            pass
-    return token_inputs(prompt=prompt,
-                        prompt_token_ids=prompt_token_ids,
-                        multi_modal_data=multi_modal_data,
-                        multi_modal_placeholders={"image": placeholder_ranges})
-
-
-@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
+class PixtralImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+
+    images: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    Shape: `(batch_size * num_images, num_channels, image_width, image_height)`
+
+    The result of stacking :attr:`ImageEncoding.tokens` from each prompt.
+    """
+
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_images, num_embeds)`
+    """
+
+    num_patches: Union[torch.Tensor, list[torch.Tensor]]
+    """Shape: `(batch_size, num_images)`"""
+
+
+class PixtralProcessorAdapter:
+    """
+    Provide a HF-compatible interface for
+    :class:`mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
+    """
+
+    def __init__(self, tokenizer: MistralTokenizer) -> None:
+        super().__init__()
+
+        self.tokenizer = tokenizer
+
+    @property
+    def image_processor(self) -> ImageEncoder:
+        image_encoder = self.tokenizer.instruct.mm_encoder
+        assert isinstance(image_encoder, ImageEncoder)
+        return image_encoder
+
+    @cached_property
+    def image_break_id(self) -> int:
+        return self.image_processor.special_ids.img_break
+
+    @cached_property
+    def image_token_id(self) -> int:
+        return self.image_processor.special_ids.img
+
+    @cached_property
+    def image_end_id(self) -> int:
+        return self.image_processor.special_ids.img_end
+
+    @cached_property
+    def image_size(self) -> int:
+        return self.image_processor.mm_config.max_image_size
+
+    @cached_property
+    def patch_size(self) -> int:
+        return self.image_processor.mm_config.image_patch_size
+
+    def __call__(
+        self,
+        text: Optional[Union[TextInput, list[TextInput]]] = None,
+        images: Optional[Union[ImageInput, list[ImageInput]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        **kwargs,
+    ) -> Mapping[str, NestedTensors]:
+        if text is None:
+            text = []
+        if not isinstance(text, list):
+            text = [text]
+        if images is None:
+            images = []
+        if not isinstance(images, list):
+            images = [images]
+
+        if not images:
+            input_ids = self.tokenizer(text).input_ids
+
+            return {"input_ids": torch.tensor(input_ids)}
+
+        # Allow dummy text, which is used for profiling as well as token inputs
+        if any(len(t) > 0 for t in text):
+            raise ValueError(
+                "You've passed text inputs instead of token inputs. "
+                "Make sure to process your input via `mistral_common`'s "
+                "tokenizer or pass a chat completion request. "
+                "For more info, see: "
+                "https://github.com/vllm-project/vllm/issues/8411.")
+
+        image_token_id = self.image_token_id
+
+        images_processed = list[torch.Tensor]()
+        images_tokens = list[torch.Tensor]()
+        images_embed_is_patch = list[torch.Tensor]()
+        images_num_patches = list[int]()
+
+        for image in images:
+            image_inputs = self.image_processor(ImageChunk(image=image))
+
+            image_processed = torch.tensor(image_inputs.image)
+            image_tokens = torch.tensor(image_inputs.tokens)
+
+            images_processed.append(image_processed)
+            images_tokens.append(image_tokens)
+            images_embed_is_patch.append(image_tokens == image_token_id)
+            images_num_patches.append(len(image_tokens))
+
+        return {
+            "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
+            "images": images_processed,
+            "embed_is_patch": images_embed_is_patch,
+            "num_patches": torch.tensor(images_num_patches),
+        }
+
+
+class PixtralProcessingInfo(BaseProcessingInfo):
+
+    def get_tokenizer(self) -> MistralTokenizer:
+        tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
+        if not isinstance(tokenizer, MistralTokenizer):
+            raise ValueError("This model requires `--tokenizer-mode mistral`")
+
+        return tokenizer
+
+    def get_hf_processor(self) -> PixtralProcessorAdapter:
+        return PixtralProcessorAdapter(self.get_tokenizer())
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None}
+
+    def get_mm_max_tokens_per_item(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> Mapping[str, int]:
+        return {"image": self.get_max_image_tokens()}
+
+    def get_vision_config(
+        self,
+        processor: Optional[PixtralProcessorAdapter] = None,
+    ):
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        return PixtralVisionConfig(
+            image_size=processor.image_size,
+            patch_size=processor.patch_size,
+        )
+
+    def get_num_image_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Optional[PixtralProcessorAdapter] = None,
+    ) -> int:
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        ncols, nrows = processor.image_processor._image_to_num_tokens(
+            Image.new("RGB", (image_width, image_height)))
+
+        return (ncols + 1) * nrows
+
+    def get_image_size_with_most_features(self) -> ImageSize:
+        image_processor = self.get_hf_processor().image_processor
+        max_image_size = image_processor.mm_config.max_image_size
+
+        return ImageSize(width=max_image_size, height=max_image_size)
+
+    def get_max_image_tokens(self) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()
+
+        return self.get_num_image_tokens(
+            image_width=target_width,
+            image_height=target_height,
+        )
+
+
+class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
+
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        num_images = mm_counts.get("image", 0)
+
+        target_width, target_height = \
+            self.info.get_image_size_with_most_features()
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=target_width,
+                                   height=target_height,
+                                   num_images=num_images)
+        }
+
+        return ProcessorInputs(
+            prompt_text="",
+            mm_data=mm_data,
+        )
+
+
+class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
+                                 ):
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: Mapping[str, NestedTensors],
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            images=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
+            num_patches=MultiModalFieldConfig.batched("image"),
+        )
+
+    def _get_prompt_updates(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> Sequence[PromptUpdate]:
+        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+
+        image_break_id = processor.image_break_id
+        image_token_id = processor.image_token_id
+        image_end_id = processor.image_end_id
+
+        def get_replacement(item_idx: int):
+            images = mm_items.get_items("image", ImageProcessorItems)
+            image_size = images.get_image_size(item_idx)
+
+            ncols, nrows = processor.image_processor._image_to_num_tokens(
+                Image.new("RGB", (image_size.width, image_size.height)))
+
+            tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
+            tokens[-1] = image_end_id
+
+            return tokens
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target="",  # Never match the prompt (see below note)
+                replacement=get_replacement,
+            ),
+        ]
+
+    def _cached_apply_hf_processor(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> tuple[list[int], MultiModalKwargs, bool]:
+        prompt_ids, mm_kwargs, _ = super()._cached_apply_hf_processor(
+            prompt=prompt,
+            mm_data_items=mm_data_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+        )
+
+        # NOTE: The tokens are already inserted by the chat template
+        return prompt_ids, mm_kwargs, True
+
+
+@MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor,
+                                        info=PixtralProcessingInfo,
+                                        dummy_inputs=PixtralDummyInputsBuilder)
 class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsPP):
 
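A toy illustration (token ids invented, not part of the patch) of the per-image grid layout produced by `get_replacement` above, for an image that tokenizes to 3 columns and 2 rows; each row of patch tokens is closed by a break token, and the final break becomes the end token, matching the `embed_is_patch` masks built in `PixtralProcessorAdapter.__call__`:

    IMG, BREAK, END = 10, 12, 13  # hypothetical special token ids
    ncols, nrows = 3, 2
    tokens = ([IMG] * ncols + [BREAK]) * nrows
    tokens[-1] = END
    assert tokens == [10, 10, 10, 12, 10, 10, 10, 13]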
@@ -191,13 +343,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
             if key in dataclass_fields
         }
 
-        if not ("image_break_token_id" in vision_args
-                and "image_end_token_id" in vision_args):
-            raise ValueError(
-                "'image_break_token_id' and 'image_end_token_id' not found "
-                "in the vision_encoder arguments. Please download the latest "
-                "version of 'params.json' from the model repository.")
-
         self.vision_args = VisionEncoderArgs(**vision_args)
 
         # init MistralForCausalLM
@@ -221,36 +366,92 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         return get_sampler()
 
+    def _parse_and_validate_image_input(
+            self, **kwargs: object) -> Optional[PixtralImagePixelInputs]:
+        images = kwargs.pop("images", None)
+        if images is None:
+            return None
+
+        if not isinstance(images, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of images. "
+                             f"Got type: {type(images)}")
+
+        embed_is_patch = kwargs.pop("embed_is_patch")
+        if not isinstance(embed_is_patch, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of embed_is_patch. "
+                             f"Got type: {type(embed_is_patch)}")
+
+        num_patches = kwargs.pop("num_patches")
+        if not isinstance(num_patches, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of num_patches. "
+                             f"Got type: {type(num_patches)}")
+
+        return PixtralImagePixelInputs(
+            type="pixel_values",
+            images=flatten_bn(images),
+            embed_is_patch=embed_is_patch,
+            num_patches=num_patches,
+        )
+
+    def _process_image_input(
+        self,
+        image_input: PixtralImagePixelInputs,
+    ) -> tuple[torch.Tensor, ...]:
+        images = image_input["images"]
+
+        image_features = self.vision_encoder(images)
+        feature_sizes = [
+            image_feature.shape[0] for image_feature in image_features
+        ]
+
+        image_embeds = self.vision_language_adapter(torch.cat(image_features))
+        image_embeds = torch.split(image_embeds, feature_sizes)
+        return image_embeds
+
+    def _get_mm_embeds(
+        self,
+        features: torch.Tensor,  # Shape: (num_patch, d)
+        num_patches: torch.Tensor,  # Shape: (num_images,)
+        embed_is_patch: torch.Tensor,  # Shape: (num_images, num_embeds)
+    ) -> tuple[torch.Tensor, ...]:
+        """Scatter the patch features into a contiguous tensor that corresponds
+        to the embedding tokens defined by the multimodal processor.
+
+        Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
+        """
+        # Insert columns of nan values according to `embed_is_patch`. This work
+        # ideally should be done in `_process_image_input`, but
+        # `_process_image_input` is used in both V0 and V1 path. It's safer to
+        # put the logic here.
+        # FIXME: Move this logic to `_process_image_input` when v0 is
+        # deprecated. Merge this function with `Molmo._get_mm_embeds`.
+        num_patches_per_image: list[int] = num_patches.tolist()
+
+        embeds_flat = features.new_full(
+            (sum(num_patches_per_image), *features.shape[1:]),
+            fill_value=torch.nan,
+        )
+        embeds_flat[embed_is_patch.view(-1)] = features
+
+        return embeds_flat.split(num_patches_per_image)
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
-        image_input, image_tokens = self._parse_and_validate_image_input(
-            **kwargs)
+        image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
 
-        vision_embeddings = self._process_image_input(image_input)
+        image_features = self._process_image_input(image_input)
 
-        # NOTE: We patch the outputs of the vision encoder with embeddings
-        # from `[IMG_BREAK]` and `[IMG_END]` tokens.
-        image_embeds = self.language_model.get_input_embeddings(image_tokens)
-        image_token_mask = image_tokens == self.vision_args.image_token_id
-        image_embeds[image_token_mask] = vision_embeddings
+        if kwargs.get("v0_path", False):
+            return image_features
 
-        # NOTE: Image embeddings are split into separate tensors for each image
-        # by the indices of `[IMG_END]` token.
-        image_end_mask = image_tokens == self.vision_args.image_end_token_id
-        split_indices = torch.where(image_end_mask)[0] + 1
-        if len(split_indices) <= 1:
-            # Do not split, return as tensor of shape [1, fs, hs]
-            return image_embeds.unsqueeze(0)
-
-        # If the last split index is the last index in image_tokens, we
-        # ignore it to avoid empty split tensor
-        if split_indices[-1] == len(image_tokens):
-            split_indices = split_indices[:-1]
-
-        image_embeds = image_embeds.tensor_split(split_indices.cpu())
-        return image_embeds
+        return flatten_2d_lists(
+            self._get_mm_embeds(*args) for args in zip(
+                image_features,
+                image_input["num_patches"],
+                image_input["embed_is_patch"],
+            ))
 
     def get_input_embeddings(
         self,
@@ -259,12 +460,17 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
+            # Extract the patch tokens
+            patch_embeddings = json_map_leaves(
+                lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
+                cast(JSONTree[torch.Tensor], multimodal_embeddings),
+            )
             inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings, [
-                    self.vision_args.image_token_id,
-                    self.vision_args.image_break_token_id,
-                    self.vision_args.image_end_token_id,
-                ])
+                input_ids,
+                inputs_embeds,
+                cast(NestedTensors, patch_embeddings),
+                self.vision_args.image_token_id,
+            )
         return inputs_embeds
 
     def forward(
@@ -275,14 +481,14 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        """Run forward pass for pixtral.
-        """
+        """Run forward pass for pixtral."""
         if intermediate_tensors is not None:
             inputs_embeds = None
 
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
+            kwargs.update({"v0_path": True})
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
             inputs_embeds = self.get_input_embeddings(input_ids,
                                                       vision_embeddings)
@ -295,47 +501,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,

         return hidden_states

-    def _parse_and_validate_image_input(
-        self,
-        images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor],
-                               torch.Tensor]] = None,
-        image_tokens: Optional[torch.Tensor] = None,
-    ) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]:
-        if images is None:
-            return None, None
-
-        if isinstance(images, torch.Tensor):
-            # if passed as batch take all images
-            N, B, C, W, H = images.shape
-            images = images.reshape(N * B, C, W, H)
-            images = [images[i] for i in range(images.size(0))]
-        elif isinstance(images, list):
-            # if passed as list flatten lists of tensors
-            flatten_images = []
-            for imgs_per_req in images:
-                imgs_per_req = [
-                    imgs_per_req[i] for i in range(imgs_per_req.size(0))
-                ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req
-
-                flatten_images.extend(imgs_per_req)
-
-            images = flatten_images
-
-        if isinstance(image_tokens, torch.Tensor):
-            # image_tokens are batched
-            image_tokens = image_tokens.flatten()
-        elif isinstance(image_tokens, list):
-            # image_tokens are of different lengths thus passed as a list
-            image_tokens = torch.cat(image_tokens)
-
-        assert image_tokens.dim() == 1
-
-        return images, image_tokens
-
-    def _process_image_input(self,
-                             image_input: List[torch.Tensor]) -> torch.Tensor:
-        return self.vision_language_adapter(self.vision_encoder(image_input))
-
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
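The deleted validator accepted images either as one batched tensor or as nested per-request lists, and normalized both to a flat list of per-image tensors. A condensed, behavior-equivalent sketch (dimension names follow the deleted code; the function name is ours):

import torch

def flatten_image_batch(images):
    if isinstance(images, torch.Tensor):
        # Batched as (N, B, C, W, H): merge the first two dims, then unbind.
        N, B, C, W, H = images.shape
        return list(images.reshape(N * B, C, W, H))
    # Nested per-request lists (or tensors): flatten one level.
    flat = []
    for per_req in images:
        flat.extend(list(per_req) if isinstance(per_req, torch.Tensor)
                    else per_req)
    return flat

print(len(flatten_image_batch(torch.randn(2, 3, 3, 16, 16))))  # 6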
@ -400,8 +565,6 @@ class VisionEncoderArgs:
     num_attention_heads: int
     rope_theta: float  # for rope-2D
     image_token_id: int
-    image_break_token_id: int
-    image_end_token_id: int
     adapter_bias: bool = True


@ -637,9 +800,13 @@ class VisionTransformer(nn.Module):
             self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
         ]

+        patch_embeds = [
+            p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
+        ]
+        embed_sizes = [p.shape[1] for p in patch_embeds]
+
         # flatten to a single sequence
-        patch_embeds = torch.cat(
-            [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
+        patch_embeds = torch.cat(patch_embeds, dim=1)
         patch_embeds = self.ln_pre(patch_embeds)

         # positional embeddings
@ -655,8 +822,8 @@ class VisionTransformer(nn.Module):
                              "with the Mistral format")
         out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

-        # remove batch dimension of the single sequence
-        return out.squeeze(0)
+        # squeeze dim 0 and split into separate tensors for each image
+        return torch.split(out.squeeze(0), embed_sizes)


 class VisionLanguageAdapter(nn.Module):
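Recording `embed_sizes` before the concatenation is what lets the final `torch.split` undo the packing. A toy round trip (hidden size and lengths invented, with a stand-in for the transformer call):

import torch

# Two images' patch embeddings with different sequence lengths.
per_image = [torch.randn(1, 9, 16), torch.randn(1, 4, 16)]
embed_sizes = [p.shape[1] for p in per_image]

packed = torch.cat(per_image, dim=1)   # (1, 13, 16): one flat sequence
out = packed                           # stand-in for self.transformer(...)

chunks = torch.split(out.squeeze(0), embed_sizes)
print([c.shape for c in chunks])  # [torch.Size([9, 16]), torch.Size([4, 16])]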
@ -978,9 +1145,9 @@ class PixtralHFVisionModel(nn.Module):

     def forward(
         self,
-        pixel_values: List[torch.Tensor],
+        pixel_values: list[torch.Tensor],
         feature_sample_layers: Optional[list[int]] = None,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, ...]:
         """
         Args:
             pixel_values: Each image to be processed will be a separate tensor
@ -1039,8 +1206,7 @@ class PixtralHFVisionModel(nn.Module):
                                              self.config.num_hidden_layers)

         # squeeze dim 0 and split into separate tensors for each image
-        out = torch.split(torch.squeeze(out), embed_sizes)
-        return out
+        return torch.split(out.squeeze(0), embed_sizes)

 # (TODO) Add prefix argument for filtering out weights to be loaded
 # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
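Beyond merging the two lines, this rewrite fixes a latent shape bug: `torch.squeeze(out)` with no `dim` argument drops every size-1 dimension, not just the batch dimension, which mangles the edge case of a single one-patch image:

import torch

out = torch.randn(1, 1, 16)          # batch dim, one patch, hidden size 16
print(torch.squeeze(out).shape)      # torch.Size([16]) -- patch dim lost too
print(out.squeeze(0).shape)          # torch.Size([1, 16]) -- batch dim only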
@ -77,7 +77,9 @@ class PromptIndexTargets:
         else:
             if isinstance(prefix, str):
                 # Make both `list[int]`
-                prefix = encode_tokens(tokenizer, prefix)
+                prefix = encode_tokens(tokenizer,
+                                       prefix,
+                                       add_special_tokens=False)

             match_idx = len(prefix)
             return match_idx if prompt[:match_idx] == prefix else None
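Passing `add_special_tokens=False` matters here because a bare `encode` may wrap the prefix in BOS/CLS-style specials, which then never line up with token-level prefix matching. Illustrated with a HF tokenizer (model choice arbitrary; requires downloading it):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")

prompt = tok.encode("hello world")                    # wrapped in [CLS]/[SEP]
with_specials = tok.encode("hello")                   # also wrapped
plain = tok.encode("hello", add_special_tokens=False)

print(prompt[:len(with_specials)] == with_specials)   # False: [SEP] intrudes
print(prompt[1:1 + len(plain)] == plain)              # True without specials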
@ -318,7 +320,7 @@ def _cached_encode(
     tokenizer: AnyTokenizer,
     text: str,
     *,
-    add_special_tokens: bool = False,
+    add_special_tokens: Optional[bool] = None,
 ) -> list[int]:
     return encode_tokens(tokenizer,
                          text,
@ -330,7 +332,7 @@ def _cached_decode(
     tokenizer: AnyTokenizer,
     token_ids: tuple[int, ...],
     *,
-    skip_special_tokens: bool = False,
+    skip_special_tokens: Optional[bool] = None,
 ) -> str:
     return decode_tokens(tokenizer,
                          list(token_ids),
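As an aside, the `tuple[int, ...]` parameter (rather than `list[int]`) is what keeps a memoized helper like this workable: cache keys must be hashable. A sketch of the pattern under that assumption, with a stand-in for the real decode:

from functools import lru_cache

@lru_cache(maxsize=2048)  # cache size illustrative
def cached_decode(token_ids: tuple[int, ...]) -> str:
    # Tuples hash, so repeat decodes of identical ids hit the cache;
    # a list[int] argument would raise TypeError: unhashable type.
    return " ".join(map(str, token_ids))  # stand-in for tokenizer.decode

cached_decode((1, 2, 3))  # computed once
cached_decode((1, 2, 3))  # served from the cache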
@ -395,7 +397,9 @@ class _BoundPromptSequence:
     def token_ids(self) -> list[int]:
         if self._token_ids is None:
             assert self._text is not None
-            self._token_ids = _cached_encode(self.tokenizer, self._text)
+            self._token_ids = _cached_encode(self.tokenizer,
+                                             self._text,
+                                             add_special_tokens=False)

         return self._token_ids

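The surrounding property is a lazy-memoization pattern: tokenize on first access, reuse afterwards. A freestanding sketch of the same shape (the character-code "encoder" is a stand-in):

from typing import Optional

class BoundText:
    def __init__(self, text: str) -> None:
        self._text = text
        self._token_ids: Optional[list[int]] = None

    @property
    def token_ids(self) -> list[int]:
        # Encode on first access only; later reads reuse the cached list.
        if self._token_ids is None:
            self._token_ids = [ord(c) for c in self._text]
        return self._token_ids

bt = BoundText("hi")
assert bt.token_ids is bt.token_ids  # same list object, encoded once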
@ -1046,7 +1050,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
-    ) -> list[PromptUpdate]:
+    ) -> Sequence[PromptUpdate]:
         """
         Given the original multi-modal items for this modality
         and HF-processed data, output the updates to perform.
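Widening the return annotation from `list` to `Sequence` lets subclasses return any read-only sequence; since callers only iterate, a tuple is now a valid override. A minimal typing sketch (names invented):

from typing import Sequence

def get_updates_strict() -> list[str]:
    return ["replace <image>"]       # a tuple here would fail type checking

def get_updates_flexible() -> Sequence[str]:
    return ("replace <image>",)      # tuples, lists, etc. all type-check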
@ -34,13 +34,20 @@ def decode_tokens(
     tokenizer: AnyTokenizer,
     token_ids: list[int],
     *,
-    skip_special_tokens: bool = False,
+    skip_special_tokens: Optional[bool] = None,
 ) -> str:
     """
     Backend-agnostic equivalent of HF's
-    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
+    :code:`tokenizer.decode(token_ids, ...)`.
+
+    :code:`skip_special_tokens=None` means to use the backend's default
+    settings.
     """
-    return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
+    if skip_special_tokens is not None:
+        return tokenizer.decode(token_ids,
+                                skip_special_tokens=skip_special_tokens)
+
+    return tokenizer.decode(token_ids)


 def encode_tokens(
@ -51,10 +58,14 @@ def encode_tokens(
 ) -> list[int]:
     """
     Backend-agnostic equivalent of HF's
-    :code:`tokenizer.encode(text, add_special_tokens=...)`.
+    :code:`tokenizer.encode(text, ...)`.
+
+    :code:`add_special_tokens=None` means to use the backend's default
+    settings.
     """
     if add_special_tokens is not None:
         return tokenizer.encode(text, add_special_tokens=add_special_tokens)

     return tokenizer.encode(text)


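Both helpers now share a tri-state convention: the old hard-coded `False` default silently overrode whatever the backend tokenizer would have done, while `None` means "do not pass the flag at all". A toy backend makes the difference visible (all names invented):

from typing import Optional

class ToyTokenizer:
    # A backend whose own default is to skip specials (token id 0 here).
    def decode(self, ids: list[int], skip_special_tokens: bool = True) -> str:
        kept = [i for i in ids if not (skip_special_tokens and i == 0)]
        return " ".join(map(str, kept))

def decode_tokens(tok: ToyTokenizer, ids: list[int], *,
                  skip_special_tokens: Optional[bool] = None) -> str:
    if skip_special_tokens is not None:
        return tok.decode(ids, skip_special_tokens=skip_special_tokens)
    return tok.decode(ids)  # defer to the backend's default

tok = ToyTokenizer()
print(decode_tokens(tok, [0, 7, 9]))                             # '7 9'
print(decode_tokens(tok, [0, 7, 9], skip_special_tokens=False))  # '0 7 9'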
@ -845,7 +845,7 @@ def is_list_of(
     assert_never(check)


-def flatten_2d_lists(lists: list[list[T]]) -> list[T]:
+def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
     """Flatten a list of lists to a single list."""
     return [item for sublist in lists for item in sublist]

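The comprehension never needed indexing or `len`, so `Iterable[Iterable[T]]` is the honest signature; generators and tuples now type-check as inputs. For instance:

from typing import Iterable, TypeVar

T = TypeVar("T")

def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
    """Flatten an iterable of iterables into a single list."""
    return [item for sublist in lists for item in sublist]

# A generator of tuples is now a perfectly valid argument:
print(flatten_2d_lists(tuple(range(i)) for i in range(4)))
# [0, 0, 1, 0, 1, 2]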