[Model] Support Pixtral models in the HF Transformers format (#9036)
parent 67a7e5ef38
commit 3921a2f29e

docs/source/models/supported_models.rst:
@@ -437,7 +437,7 @@ Text Generation
   * - :code:`PixtralForConditionalGeneration`
     - Pixtral
     - T + I\ :sup:`+`
-    - :code:`mistralai/Pixtral-12B-2409`
+    - :code:`mistralai/Pixtral-12B-2409`, :code:`mistral-community/pixtral-12b` etc.
     -
     - ✅︎
   * - :code:`QWenLMHeadModel`

examples/offline_inference_vision_language.py:
@@ -277,6 +277,22 @@ def run_qwen2_vl(question: str, modality: str):
     return llm, prompt, stop_token_ids


+# Pixtral HF-format
+def run_pixtral_hf(question: str, modality: str):
+    assert modality == "image"
+
+    model_name = "mistral-community/pixtral-12b"
+
+    llm = LLM(
+        model=model_name,
+        max_model_len=8192,
+    )
+
+    prompt = f"<s>[INST]{question}\n[IMG][/INST]"
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
+
+
 # LLama 3.2
 def run_mllama(question: str, modality: str):
     assert modality == "image"
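
Note: a minimal usage sketch (not part of the diff) of how the example script's main path would consume run_pixtral_hf; the question, the image path, and the sampling settings here are illustrative only.

    from PIL import Image
    from vllm import SamplingParams

    llm, prompt, stop_token_ids = run_pixtral_hf("What is shown here?", "image")
    image = Image.open("example.jpg")  # hypothetical local image

    outputs = llm.generate(
        {"prompt": prompt, "multi_modal_data": {"image": image}},
        sampling_params=SamplingParams(max_tokens=64,
                                       stop_token_ids=stop_token_ids),
    )
    print(outputs[0].outputs[0].text)
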
@@ -347,6 +363,7 @@ model_example_map = {
     "NVLM_D": run_nvlm_d,
     "qwen_vl": run_qwen_vl,
     "qwen2_vl": run_qwen2_vl,
+    "pixtral_hf": run_pixtral_hf,
     "mllama": run_mllama,
     "molmo": run_molmo,
     "glm4v": run_glm4v,

vllm/model_executor/layers/activation.py:
@@ -264,6 +264,8 @@ _ACTIVATION_REGISTRY = LazyDict({
     lambda: nn.ReLU(),
     "relu2":
     lambda: ReLUSquaredActivation(),
+    "silu":
+    lambda: nn.SiLU(),
     "quick_gelu":
     lambda: QuickGELU(),
 })
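
Note: the vision MLP added later in this commit resolves its activation by name via get_act_fn(config.hidden_act), and "silu" was not previously in the registry. A minimal sketch of that lookup, assuming get_act_fn resolves names through _ACTIVATION_REGISTRY as it does for the other entries:

    import torch
    import torch.nn as nn

    from vllm.model_executor.layers.activation import get_act_fn

    act = get_act_fn("silu")         # resolved through _ACTIVATION_REGISTRY
    assert isinstance(act, nn.SiLU)
    print(act(torch.tensor([1.0])))  # tensor([0.7311]) since silu(x) = x * sigmoid(x)
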

vllm/model_executor/models/llava.py:
@@ -5,7 +5,8 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
 import torch
 import torch.nn as nn
 from PIL import Image
-from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
+from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
+                          SiglipVisionConfig)

 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
@@ -22,6 +23,10 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip,
                    dummy_seq_data_for_clip, get_max_clip_image_tokens,
                    input_processor_for_clip)
 from .interfaces import SupportsMultiModal, SupportsPP
+from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
+                      dummy_seq_data_for_pixtral_hf,
+                      get_max_pixtral_hf_image_tokens,
+                      input_processor_for_pixtral_hf)
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
@@ -31,8 +36,13 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,

 class LlavaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+    data: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    Shape: `(batch_size * num_images, num_channels, height, width)`
+
+    Note that `height` or `width` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
+    """


 class LlavaImageEmbeddingInputs(TypedDict):
@@ -77,6 +87,8 @@ def get_max_llava_image_tokens(ctx: InputContext):
         num_image_tokens = get_max_clip_image_tokens(vision_config)
     elif isinstance(vision_config, SiglipVisionConfig):
         num_image_tokens = get_max_siglip_image_tokens(vision_config)
+    elif isinstance(vision_config, PixtralVisionConfig):
+        num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
     else:
         msg = f"Unsupported vision config: {type(vision_config)}"
         raise NotImplementedError(msg)
@@ -120,6 +132,17 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,

         mm_data = dummy_image_for_siglip(vision_config, num_images)
         return seq_data, mm_data
+    elif isinstance(vision_config, PixtralVisionConfig):
+        seq_data = dummy_seq_data_for_pixtral_hf(
+            vision_config,
+            seq_len,
+            num_images,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+
+        mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
+        return seq_data, mm_data

     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
@@ -163,6 +186,15 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
             image_token_id=hf_config.image_token_index,
             image_feature_size_override=image_feature_size,
         )
+    elif isinstance(vision_config, PixtralVisionConfig):
+        # We ignore image_feature_size_override since we have non-uniform
+        # image sizes for Pixtral
+        return input_processor_for_pixtral_hf(
+            model_config,
+            vision_config,
+            inputs,
+            image_token_id=hf_config.image_token_index,
+        )

     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
@@ -189,6 +221,9 @@ def _init_vision_tower(hf_config: LlavaConfig):
             vision_config,
             num_hidden_layers_override=num_hidden_layers,
         )
+    elif isinstance(vision_config, PixtralVisionConfig):
+        # TODO: allow layer override?
+        return PixtralHFVisionModel(vision_config)

     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
@@ -210,6 +245,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         self.config = config
         self.multimodal_config = multimodal_config

+        # NOTE: These are special cases for Pixtral-12B in the HF-format
+        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa
+        if (config.text_config.architectures is None
+                and config.text_config.model_type == "mistral"):
+            config.text_config.architectures = ["MistralForCausalLM"]
+        if (config.projector_hidden_act is None
+                and config.vision_config.hidden_act == "gelu"):
+            config.projector_hidden_act = "gelu"
+
         # TODO: Optionally initializes this for supporting embeddings.
         self.vision_tower = _init_vision_tower(config)
         self.multi_modal_projector = LlavaMultiModalProjector(
@@ -246,6 +290,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        image_sizes = kwargs.pop("image_sizes", None)
         image_embeds = kwargs.pop("image_embeds", None)

         if pixel_values is None and image_embeds is None:
@@ -256,6 +301,26 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")

+            # Case for models like PixtralHF that have dynamic image sizes
+            # so we need to produce a list of tensors
+            if image_sizes is not None:
+                images = pixel_values
+                if isinstance(images, torch.Tensor):
+                    # if passed as batch take all images
+                    NN, N, B, C, W, H = images.shape
+                    images = images.reshape(NN * N * B, C, W, H)
+                    images = [images[i] for i in range(images.size(0))]
+                elif isinstance(images, list):
+                    # if passed as list flatten lists of tensors
+                    while isinstance(images, list) and len(images) == 1:
+                        images = images[0]
+
+                # TODO: Add validation based on image_sizes
+                return LlavaImagePixelInputs(
+                    type="pixel_values",
+                    data=images,
+                )
+
             return LlavaImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(
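
Note: a toy illustration (hypothetical shapes, not part of the diff) of what the batched-tensor branch above does: a (NN, N, B, C, W, H) tensor is flattened into a plain list of per-image tensors so downstream code can handle non-uniform image sizes.

    import torch

    batch = torch.zeros(2, 1, 1, 3, 64, 64)          # (NN, N, B, C, W, H)
    NN, N, B, C, W, H = batch.shape
    flat = batch.reshape(NN * N * B, C, W, H)
    images = [flat[i] for i in range(flat.size(0))]  # 2 tensors of shape (3, 64, 64)
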
@@ -286,7 +351,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

     def _image_pixels_to_features(
         self,
-        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
+        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
+                            PixtralHFVisionModel],
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:


vllm/model_executor/models/pixtral.py:
@@ -3,18 +3,26 @@ from functools import cached_property
 from itertools import tee
 from typing import Iterable, List, Mapping, Optional, Tuple, Union

+import numpy
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from mistral_common.protocol.instruct.messages import ImageChunk
 from PIL import Image
-from transformers import PretrainedConfig
+from transformers import PixtralVisionConfig, PretrainedConfig
+from transformers.models.pixtral.image_processing_pixtral import (
+    _num_image_tokens)
+from transformers.models.pixtral.modeling_pixtral import (
+    PixtralRotaryEmbedding, apply_rotary_pos_emb,
+    generate_block_attention_mask, position_ids_in_meshgrid)
 from xformers.ops.fmha import memory_efficient_attention
 from xformers.ops.fmha.attn_bias import BlockDiagonalMask

 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
+from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
+                         token_inputs)
+from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
@@ -25,6 +33,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SequenceData
+from vllm.transformers_utils.processor import cached_get_processor
+from vllm.utils import is_list_of

 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import init_vllm_registered_model
@@ -576,3 +586,397 @@ class VisionLanguageAdapter(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w_out(self.gelu(self.w_in(x)))
+
+
+#### HF Transformers version of Pixtral ####
+# Based off https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py
+# This model follows the Llava family, meaning image embeddings are placed
+# instead of the `[IMG]` token placeholders.
+# The model uses [`PixtralVisionModel`] for its vision encoder,
+# and [`MistralForCausalLM`] for its language decoder.
+
+
+def get_pixtral_hf_patch_grid_length(*, image_size: int,
+                                     patch_size: int) -> int:
+    # Since interpolation is applied, the image size need not be divisible
+    # assert image_size % patch_size == 0
+    return image_size // patch_size
+
+
+def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int:
+    grid_length = get_pixtral_hf_patch_grid_length(image_size=image_size,
+                                                   patch_size=patch_size)
+    return grid_length * grid_length
+
+
+def get_max_pixtral_hf_image_feature_size(
+        hf_config: PixtralVisionConfig) -> int:
+    return get_pixtral_hf_num_patches(image_size=hf_config.image_size,
+                                      patch_size=hf_config.patch_size)
+
+
+def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int:
+    return get_max_pixtral_hf_image_feature_size(hf_config)
+
+
+def dummy_seq_data_for_pixtral_hf(
+    hf_config: PixtralVisionConfig,
+    seq_len: int,
+    num_images: int,
+    *,
+    image_token_id: int,
+    image_feature_size_override: Optional[int] = None,
+):
+    if image_feature_size_override is None:
+        image_feature_size = get_max_pixtral_hf_image_feature_size(hf_config)
+    else:
+        image_feature_size = image_feature_size_override
+
+    return SequenceData.from_prompt_token_counts(
+        (image_token_id, image_feature_size * num_images),
+        (0, seq_len - image_feature_size * num_images),
+    )
+
+
+def dummy_image_for_pixtral_hf(
+    hf_config: PixtralVisionConfig,
+    num_images: int,
+    *,
+    image_width_override: Optional[int] = None,
+    image_height_override: Optional[int] = None,
+):
+    width = height = hf_config.image_size
+    if image_width_override is not None:
+        width = image_width_override
+    if image_height_override is not None:
+        height = image_height_override
+
+    image = Image.new("RGB", (width, height), color=0)
+    return {"image": image if num_images == 1 else [image] * num_images}
+
+
+def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
+                                      image_width: int,
+                                      image_height: int) -> Tuple[int, int]:
+    # Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501
+    # https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180 # noqa: E501
+    max_width, max_height = hf_config.image_size, hf_config.image_size
+    patch_width, patch_height = hf_config.patch_size, hf_config.patch_size
+
+    ratio = max(image_width / max_width, image_height / max_height)
+
+    if ratio > 1:
+        image_width = int(numpy.ceil(image_width / ratio))
+        image_height = int(numpy.ceil(image_height / ratio))
+
+    num_height_tokens, num_width_tokens = _num_image_tokens(
+        (image_height, image_width), (patch_height, patch_width))
+
+    return num_width_tokens, num_height_tokens
+
+
+def input_processor_for_pixtral_hf(
+    model_config: ModelConfig,
+    hf_config: PixtralVisionConfig,
+    inputs: DecoderOnlyInputs,
+    *,
+    image_token_id: int,
+    image_feature_size_override: Optional[Union[int, List[int]]] = None,
+) -> DecoderOnlyInputs:
+    assert image_feature_size_override is None, (
+        "image_feature_size_override is not supported for Pixtral")
+
+    multi_modal_data = inputs.get("multi_modal_data")
+    if multi_modal_data is None or "image" not in multi_modal_data:
+        return inputs
+
+    processor = cached_get_processor(model_config.model)
+
+    image_data = multi_modal_data["image"]
+    if isinstance(image_data, Image.Image):
+        image_data = [image_data]
+    elif not is_list_of(image_data, Image.Image):
+        raise TypeError(f"Invalid image type: {type(image_data)}")
+
+    new_prompt = inputs.get("prompt")
+    new_token_ids = inputs["prompt_token_ids"]
+
+    # Update new_prompt if present
+    if new_prompt:
+        replace_strings = []
+        for image in image_data:
+            w, h = image.size
+
+            (num_width_tokens,
+             num_height_tokens) = get_pixtral_hf_image_feature_size(
+                 hf_config, image_width=w, image_height=h)
+
+            replace_tokens = [[processor.image_token] * num_width_tokens +
+                              [processor.image_break_token]
+                              ] * num_height_tokens
+            # Flatten list
+            replace_tokens = [
+                item for sublist in replace_tokens for item in sublist
+            ]
+            replace_tokens[-1] = processor.image_end_token
+            replace_str = "".join(replace_tokens)
+            replace_strings.append(replace_str)
+            new_prompt = new_prompt.replace(processor.image_token,
+                                            "<placeholder>", 1)
+
+        while "<placeholder>" in new_prompt:
+            replace_str = replace_strings.pop(0)
+            new_prompt = new_prompt.replace("<placeholder>", replace_str, 1)
+
+    # Update new_token_ids
+    image_token_id = 10
+    image_break_id = 12
+    image_end_id = 13
+    placeholder_token_id = -999
+    replace_tokens_list = []
+    for image in image_data:
+        w, h = image.size
+
+        num_width_tokens, num_height_tokens = get_pixtral_hf_image_feature_size(
+            hf_config, image_width=w, image_height=h)
+
+        replace_tokens = [[image_token_id] * num_width_tokens +
+                          [image_break_id]] * num_height_tokens
+        # Flatten list
+        replace_tokens = [
+            item for sublist in replace_tokens for item in sublist
+        ]
+        replace_tokens[-1] = image_end_id
+        replace_tokens_list.append(replace_tokens)
+        # Replace image id with placeholder id
+        next_image_index = new_token_ids.index(image_token_id)
+        new_token_ids[next_image_index] = placeholder_token_id
+
+    while placeholder_token_id in new_token_ids:
+        replace_tokens = replace_tokens_list.pop(0)
+        next_image_index = new_token_ids.index(placeholder_token_id)
+        prefix = new_token_ids[:next_image_index]
+        postfix = new_token_ids[next_image_index + 1:]
+        new_token_ids = prefix + replace_tokens + postfix
+
+    # NOTE: Create a defensive copy of the original inputs
+    return token_inputs(prompt_token_ids=new_token_ids,
+                        prompt=new_prompt,
+                        multi_modal_data=multi_modal_data)
+
+
+class PixtralHFMLP(nn.Module):
+
+    def __init__(self, config: PixtralVisionConfig):
+        super().__init__()
+        assert config.intermediate_size is not None
+        self.gate_proj = nn.Linear(config.hidden_size,
+                                   config.intermediate_size,
+                                   bias=False)
+        self.up_proj = nn.Linear(config.hidden_size,
+                                 config.intermediate_size,
+                                 bias=False)
+        self.down_proj = nn.Linear(config.intermediate_size,
+                                   config.hidden_size,
+                                   bias=False)
+        self.act = get_act_fn(config.hidden_act)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
+
+
+class PixtralHFAttention(nn.Module):
+
+    def __init__(self, config: PixtralVisionConfig):
+        super().__init__()
+        self.config = config
+        assert not config.hidden_size % config.num_attention_heads
+        self.n_heads = config.num_attention_heads
+        self.head_dim = config.hidden_size // config.num_attention_heads
+
+        self.scale = self.head_dim**-0.5
+
+        self.q_proj = nn.Linear(config.hidden_size,
+                                config.hidden_size,
+                                bias=False)
+        self.k_proj = nn.Linear(config.hidden_size,
+                                config.hidden_size,
+                                bias=False)
+        self.v_proj = nn.Linear(config.hidden_size,
+                                config.hidden_size,
+                                bias=False)
+        self.o_proj = nn.Linear(config.hidden_size,
+                                config.hidden_size,
+                                bias=False)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        position_embeddings: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Input shape: Batch x Time x Channel"""
+
+        batch_size, patches, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(batch_size, patches, self.n_heads,
+                                         self.head_dim).transpose(1, 2)
+        key_states = key_states.view(batch_size, patches, self.n_heads,
+                                     self.head_dim).transpose(1, 2)
+        value_states = value_states.view(batch_size, patches, self.n_heads,
+                                         self.head_dim).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states,
+                                                        key_states,
+                                                        cos,
+                                                        sin,
+                                                        unsqueeze_dim=0)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(
+            2, 3)) * self.scale
+
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights,
+                                             dim=-1,
+                                             dtype=torch.float32).to(
+                                                 query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(batch_size, patches, -1)
+
+        return self.o_proj(attn_output)
+
+
+class PixtralHFTransformerBlock(nn.Module):
+
+    def __init__(self, config: PixtralVisionConfig):
+        super().__init__()
+        self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
+        self.attention = PixtralHFAttention(config)
+        self.feed_forward = PixtralHFMLP(config)
+        self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        position_embeddings: torch.Tensor,
+    ) -> torch.Tensor:
+        r = self.attention.forward(self.attention_norm(hidden_states),
+                                   attention_mask=attention_mask,
+                                   position_embeddings=position_embeddings)
+        h = hidden_states + r
+        r = self.feed_forward.forward(self.ffn_norm(h))
+        out = h + r
+        return out
+
+
+class PixtralHFTransformer(nn.Module):
+
+    def __init__(self, config: PixtralVisionConfig):
+        super().__init__()
+        self.layers = torch.nn.ModuleList()
+        for _ in range(config.num_hidden_layers):
+            self.layers.append(PixtralHFTransformerBlock(config))
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_mask: torch.Tensor,
+        position_embeddings: torch.Tensor,
+    ) -> torch.Tensor:
+        for layer in self.layers:
+            x = layer(x, attention_mask, position_embeddings)
+        return x
+
+
+class PixtralHFVisionModel(nn.Module):
+
+    def __init__(self, config: PixtralVisionConfig):
+        super().__init__()
+
+        self.config = config
+        self.patch_conv = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=config.hidden_size,
+            kernel_size=config.patch_size,
+            stride=config.patch_size,
+            bias=False,
+        )
+        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
+        self.transformer = PixtralHFTransformer(config)
+        self.dtype = next(self.parameters()).dtype
+        self.device = next(self.parameters()).device
+        self.patch_positional_embedding = PixtralRotaryEmbedding(
+            config, self.device)
+
+    def forward(
+        self,
+        pixel_values: List[torch.Tensor],
+    ) -> torch.Tensor:
+        """
+        Args:
+            pixel_values: tensor of token features for
+                all tokens of all images of shape (N_toks, D)
+        Returns:
+            image_features: tensor of token features for
+                all tokens of all images of shape (N_toks, D)
+        """
+        # pass images through initial convolution independently
+        patch_embeds_list = [
+            self.patch_conv(
+                img.reshape(-1, img.shape[-3], img.shape[-2],
+                            img.shape[-1]).to(self.dtype))
+            for img in pixel_values
+        ]
+
+        # flatten to a single sequence
+        patch_embeds = torch.cat(
+            [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
+        patch_embeds = self.ln_pre(patch_embeds)
+
+        # positional embeddings
+        position_ids = position_ids_in_meshgrid(
+            patch_embeds_list,
+            max_width=self.config.image_size // self.config.patch_size).to(
+                self.device)
+
+        position_embedding = self.patch_positional_embedding(
+            patch_embeds, position_ids)
+        attention_mask = generate_block_attention_mask(
+            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
+            patch_embeds)
+        out = self.transformer(patch_embeds, attention_mask,
+                               position_embedding)
+
+        return out
+
+    # (TODO) Add prefix argument for filtering out weights to be loaded
+    # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = []
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                param = params_dict[name.replace(weight_name, param_name)]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
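
Note: a worked example of the token-count arithmetic in get_pixtral_hf_image_feature_size above. The 1300x800 input size is illustrative, and the sketch assumes the PixtralVisionConfig defaults of image_size=1024 and patch_size=16.

    from transformers import PixtralVisionConfig

    from vllm.model_executor.models.pixtral import (
        get_pixtral_hf_image_feature_size)

    cfg = PixtralVisionConfig()  # image_size=1024, patch_size=16 by default
    w_toks, h_toks = get_pixtral_hf_image_feature_size(cfg,
                                                       image_width=1300,
                                                       image_height=800)
    # ratio = max(1300/1024, 800/1024) ~= 1.27, so the image is scaled down
    # to 1024 x 631; ceil(1024/16) = 64 width tokens and ceil(631/16) = 40
    # height tokens. The prompt expansion above then emits 40 rows of 64
    # image tokens, each row closed by an image-break token and the final
    # row by an image-end token.
    print(w_toks, h_toks)  # 64 40
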

vllm/model_executor/models/qwen2_vl.py:
@@ -22,7 +22,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
-from functools import lru_cache, partial
+from functools import partial
 from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
                     Tuple, Type, TypedDict, Union)

@@ -63,7 +63,7 @@ from vllm.multimodal.base import MultiModalData
 from vllm.multimodal.image import cached_get_image_processor
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.config import uses_mrope
-from vllm.transformers_utils.processor import get_processor
+from vllm.transformers_utils.processor import cached_get_processor

 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (PPMissingLayer, get_vit_attn_backend,
@@ -544,8 +544,6 @@ class Qwen2VisionTransformer(nn.Module):

 # === Vision input helpers === #

-cached_get_processor = lru_cache(get_processor)
-

 def mm_input_mapper_for_qwen2_vl(
     ctx: InputContext,
|
@ -1,3 +1,4 @@
|
|||||||
|
from functools import lru_cache
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
|
|
||||||
@@ -37,6 +38,9 @@ def get_processor(
     return cast(ProcessorMixin, processor)


+cached_get_processor = lru_cache(get_processor)
+
+
 def get_image_processor(
     processor_name: str,
     *args: Any,
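
Note: a small sketch (the model name is illustrative) of the shared helper that replaces the module-local lru_cache previously defined in qwen2_vl.py, so that other models such as Pixtral-HF can reuse it.

    from vllm.transformers_utils.processor import cached_get_processor

    proc_a = cached_get_processor("mistral-community/pixtral-12b")
    proc_b = cached_get_processor("mistral-community/pixtral-12b")
    assert proc_a is proc_b  # lru_cache(get_processor) memoizes per processor name
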