[Model] Support Pixtral models in the HF Transformers format (#9036)

Michael Goin 2024-10-18 15:29:56 -04:00 committed by GitHub
parent 67a7e5ef38
commit 3921a2f29e
7 changed files with 503 additions and 12 deletions

View File

@@ -437,7 +437,7 @@ Text Generation
* - :code:`PixtralForConditionalGeneration`
- Pixtral
- T + I\ :sup:`+`
- :code:`mistralai/Pixtral-12B-2409`
- :code:`mistralai/Pixtral-12B-2409`, :code:`mistral-community/pixtral-12b` etc.
-
- ✅︎
* - :code:`QWenLMHeadModel`

View File

@@ -277,6 +277,22 @@ def run_qwen2_vl(question: str, modality: str):
return llm, prompt, stop_token_ids
# Pixtral HF-format
def run_pixtral_hf(question: str, modality: str):
assert modality == "image"
model_name = "mistral-community/pixtral-12b"
llm = LLM(
model=model_name,
max_model_len=8192,
)
prompt = f"<s>[INST]{question}\n[IMG][/INST]"
stop_token_ids = None
return llm, prompt, stop_token_ids
# Llama 3.2
def run_mllama(question: str, modality: str):
assert modality == "image"
@@ -347,6 +363,7 @@ model_example_map = {
"NVLM_D": run_nvlm_d,
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
"pixtral_hf": run_pixtral_hf,
"mllama": run_mllama,
"molmo": run_molmo,
"glm4v": run_glm4v,

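As a point of reference, a minimal sketch of driving the new `pixtral_hf` entry offline; the image path and sampling settings are illustrative and not part of this change:

from PIL import Image
from vllm import SamplingParams

llm, prompt, stop_token_ids = run_pixtral_hf("Describe this image.", "image")
image = Image.open("example.jpg")  # illustrative local image
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=64, stop_token_ids=stop_token_ids),
)
print(outputs[0].outputs[0].text)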
View File

@@ -264,6 +264,8 @@ _ACTIVATION_REGISTRY = LazyDict({
lambda: nn.ReLU(),
"relu2":
lambda: ReLUSquaredActivation(),
"silu":
lambda: nn.SiLU(),
"quick_gelu":
lambda: QuickGELU(),
})
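The new "silu" entry lets an activation be resolved by name for configs that use SiLU (e.g. the gated MLP in the Pixtral HF vision tower added below). A minimal sketch of how the registry is consumed, assuming only that get_act_fn looks the name up here:

import torch
from vllm.model_executor.layers.activation import get_act_fn

act = get_act_fn("silu")            # resolves to nn.SiLU() via the lazy registry
x = torch.tensor([-1.0, 0.0, 2.0])
print(act(x))                       # elementwise x * sigmoid(x)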

View File

@@ -5,7 +5,8 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
SiglipVisionConfig)
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
@@ -22,6 +23,10 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_max_clip_image_tokens,
input_processor_for_clip)
from .interfaces import SupportsMultiModal, SupportsPP
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
dummy_seq_data_for_pixtral_hf,
get_max_pixtral_hf_image_tokens,
input_processor_for_pixtral_hf)
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip)
@@ -31,8 +36,13 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
Note that `height` or `width` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""
class LlavaImageEmbeddingInputs(TypedDict):
@@ -77,6 +87,8 @@ def get_max_llava_image_tokens(ctx: InputContext):
num_image_tokens = get_max_clip_image_tokens(vision_config)
elif isinstance(vision_config, SiglipVisionConfig):
num_image_tokens = get_max_siglip_image_tokens(vision_config)
elif isinstance(vision_config, PixtralVisionConfig):
num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@@ -120,6 +132,17 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
mm_data = dummy_image_for_siglip(vision_config, num_images)
return seq_data, mm_data
elif isinstance(vision_config, PixtralVisionConfig):
seq_data = dummy_seq_data_for_pixtral_hf(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@@ -163,6 +186,15 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
elif isinstance(vision_config, PixtralVisionConfig):
# We ignore image_feature_size_override since we have non-uniform
# image sizes for Pixtral
return input_processor_for_pixtral_hf(
model_config,
vision_config,
inputs,
image_token_id=hf_config.image_token_index,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@@ -189,6 +221,9 @@ def _init_vision_tower(hf_config: LlavaConfig):
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
elif isinstance(vision_config, PixtralVisionConfig):
# TODO: allow layer override?
return PixtralHFVisionModel(vision_config)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@@ -210,6 +245,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.config = config
self.multimodal_config = multimodal_config
# NOTE: These are special cases for Pixtral-12B in the HF-format
# https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa
if (config.text_config.architectures is None
and config.text_config.model_type == "mistral"):
config.text_config.architectures = ["MistralForCausalLM"]
if (config.projector_hidden_act is None
and config.vision_config.hidden_act == "gelu"):
config.projector_hidden_act = "gelu"
# TODO: Optionally initialize this to support embeddings.
self.vision_tower = _init_vision_tower(config)
self.multi_modal_projector = LlavaMultiModalProjector(
@@ -246,6 +290,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
@@ -256,6 +301,26 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
# Case for models like PixtralHF that have dynamic image sizes,
# so we need to produce a list of tensors
if image_sizes is not None:
images = pixel_values
if isinstance(images, torch.Tensor):
# if passed as a batched tensor, flatten the batch dims to take all images
NN, N, B, C, W, H = images.shape
images = images.reshape(NN * N * B, C, W, H)
images = [images[i] for i in range(images.size(0))]
elif isinstance(images, list):
# if passed as a (possibly nested) list, unwrap any singleton nesting
while isinstance(images, list) and len(images) == 1:
images = images[0]
# TODO: Add validation based on image_sizes
return LlavaImagePixelInputs(
type="pixel_values",
data=images,
)
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
@@ -286,7 +351,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
PixtralHFVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:

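A small sketch of the reshaping that _parse_and_validate_image_input performs when image_sizes is present and pixel_values arrives as one batched tensor; the dimension names mirror the unpacking in that method and the concrete sizes are made up:

import torch

pixel_values = torch.randn(1, 1, 2, 3, 64, 80)    # batched layout (NN, N, B, C, W, H)
NN, N, B, C, W, H = pixel_values.shape
flat = pixel_values.reshape(NN * N * B, C, W, H)  # collapse the leading batch dims
images = [flat[i] for i in range(flat.size(0))]   # one tensor per image
assert len(images) == 2 and images[0].shape == (3, 64, 80)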
View File

@@ -3,18 +3,26 @@ from functools import cached_property
from itertools import tee
from typing import Iterable, List, Mapping, Optional, Tuple, Union
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk
from PIL import Image
from transformers import PretrainedConfig
from transformers import PixtralVisionConfig, PretrainedConfig
from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens)
from transformers.models.pixtral.modeling_pixtral import (
PixtralRotaryEmbedding, apply_rotary_pos_emb,
generate_block_attention_mask, position_ids_in_meshgrid)
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
@@ -25,6 +33,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import init_vllm_registered_model
@@ -576,3 +586,397 @@ class VisionLanguageAdapter(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w_out(self.gelu(self.w_in(x)))
#### HF Transformers version of Pixtral ####
# Based on https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py
# This model follows the Llava family, meaning image embeddings are
# substituted for the `[IMG]` token placeholders.
# The model uses [`PixtralVisionModel`] for its vision encoder,
# and [`MistralForCausalLM`] for its language decoder.
def get_pixtral_hf_patch_grid_length(*, image_size: int,
patch_size: int) -> int:
# Since interpolation is applied, the image size need not be divisible
# by the patch size, so we do not assert image_size % patch_size == 0.
return image_size // patch_size
def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int:
grid_length = get_pixtral_hf_patch_grid_length(image_size=image_size,
patch_size=patch_size)
return grid_length * grid_length
def get_max_pixtral_hf_image_feature_size(
hf_config: PixtralVisionConfig) -> int:
return get_pixtral_hf_num_patches(image_size=hf_config.image_size,
patch_size=hf_config.patch_size)
def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int:
return get_max_pixtral_hf_image_feature_size(hf_config)
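# Illustrative example (values assumed, not taken from this patch): with
# image_size=1024 and patch_size=16, the maximum is (1024 // 16) ** 2 = 4096
# image tokens.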
def dummy_seq_data_for_pixtral_hf(
hf_config: PixtralVisionConfig,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = get_max_pixtral_hf_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
)
def dummy_image_for_pixtral_hf(
hf_config: PixtralVisionConfig,
num_images: int,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
width = height = hf_config.image_size
if image_width_override is not None:
width = image_width_override
if image_height_override is not None:
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return {"image": image if num_images == 1 else [image] * num_images}
def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
image_width: int,
image_height: int) -> Tuple[int, int]:
# Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501
# https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180 # noqa: E501
max_width, max_height = hf_config.image_size, hf_config.image_size
patch_width, patch_height = hf_config.patch_size, hf_config.patch_size
ratio = max(image_width / max_width, image_height / max_height)
if ratio > 1:
image_width = int(numpy.ceil(image_width / ratio))
image_height = int(numpy.ceil(image_height / ratio))
num_height_tokens, num_width_tokens = _num_image_tokens(
(image_height, image_width), (patch_height, patch_width))
return num_width_tokens, num_height_tokens
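# Illustrative example (values assumed): with image_size=1024 and patch_size=16,
# a 512 x 768 (height x width) image has ratio <= 1, so it is not resized and
# yields 768 // 16 = 48 width tokens and 512 // 16 = 32 height tokens.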
def input_processor_for_pixtral_hf(
model_config: ModelConfig,
hf_config: PixtralVisionConfig,
inputs: DecoderOnlyInputs,
*,
image_token_id: int,
image_feature_size_override: Optional[Union[int, List[int]]] = None,
) -> DecoderOnlyInputs:
assert image_feature_size_override is None, (
"image_feature_size_override is not supported for Pixtral")
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
processor = cached_get_processor(model_config.model)
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_data = [image_data]
elif not is_list_of(image_data, Image.Image):
raise TypeError(f"Invalid image type: {type(image_data)}")
new_prompt = inputs.get("prompt")
new_token_ids = inputs["prompt_token_ids"]
# Update new_prompt if present
if new_prompt:
replace_strings = []
for image in image_data:
w, h = image.size
(num_width_tokens,
num_height_tokens) = get_pixtral_hf_image_feature_size(
hf_config, image_width=w, image_height=h)
replace_tokens = [[processor.image_token] * num_width_tokens +
[processor.image_break_token]
] * num_height_tokens
# Flatten list
replace_tokens = [
item for sublist in replace_tokens for item in sublist
]
replace_tokens[-1] = processor.image_end_token
replace_str = "".join(replace_tokens)
replace_strings.append(replace_str)
new_prompt = new_prompt.replace(processor.image_token,
"<placeholder>", 1)
while "<placeholder>" in new_prompt:
replace_str = replace_strings.pop(0)
new_prompt = new_prompt.replace("<placeholder>", replace_str, 1)
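# Example (illustrative): an image mapping to 2 rows x 3 columns of patches
# expands one placeholder into
# [IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG_END].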
# Update new_token_ids
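# NOTE: The hardcoded ids below correspond to the Pixtral tokenizer's [IMG],
# [IMG_BREAK] and [IMG_END] special tokens; they are fixed here rather than
# looked up from the tokenizer at runtime.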
image_token_id = 10
image_break_id = 12
image_end_id = 13
placeholder_token_id = -999
replace_tokens_list = []
for image in image_data:
w, h = image.size
num_width_tokens, num_height_tokens = get_pixtral_hf_image_feature_size(
hf_config, image_width=w, image_height=h)
replace_tokens = [[image_token_id] * num_width_tokens +
[image_break_id]] * num_height_tokens
# Flatten list
replace_tokens = [
item for sublist in replace_tokens for item in sublist
]
replace_tokens[-1] = image_end_id
replace_tokens_list.append(replace_tokens)
# Replace image id with placeholder id
next_image_index = new_token_ids.index(image_token_id)
new_token_ids[next_image_index] = placeholder_token_id
while placeholder_token_id in new_token_ids:
replace_tokens = replace_tokens_list.pop(0)
next_image_index = new_token_ids.index(placeholder_token_id)
prefix = new_token_ids[:next_image_index]
postfix = new_token_ids[next_image_index + 1:]
new_token_ids = prefix + replace_tokens + postfix
# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
class PixtralHFMLP(nn.Module):
def __init__(self, config: PixtralVisionConfig):
super().__init__()
assert config.intermediate_size is not None
self.gate_proj = nn.Linear(config.hidden_size,
config.intermediate_size,
bias=False)
self.up_proj = nn.Linear(config.hidden_size,
config.intermediate_size,
bias=False)
self.down_proj = nn.Linear(config.intermediate_size,
config.hidden_size,
bias=False)
self.act = get_act_fn(config.hidden_act)
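# NOTE: hidden_act is resolved by name via vLLM's activation registry; the new
# "silu" entry above covers configs that use a SiLU-gated MLP here.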
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
class PixtralHFAttention(nn.Module):
def __init__(self, config: PixtralVisionConfig):
super().__init__()
self.config = config
assert not config.hidden_size % config.num_attention_heads
self.n_heads = config.num_attention_heads
self.head_dim = config.hidden_size // config.num_attention_heads
self.scale = self.head_dim**-0.5
self.q_proj = nn.Linear(config.hidden_size,
config.hidden_size,
bias=False)
self.k_proj = nn.Linear(config.hidden_size,
config.hidden_size,
bias=False)
self.v_proj = nn.Linear(config.hidden_size,
config.hidden_size,
bias=False)
self.o_proj = nn.Linear(config.hidden_size,
config.hidden_size,
bias=False)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
"""Input shape: Batch x Time x Channel"""
batch_size, patches, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, patches, self.n_heads,
self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, patches, self.n_heads,
self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, patches, self.n_heads,
self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states,
key_states,
cos,
sin,
unsqueeze_dim=0)
attn_weights = torch.matmul(query_states, key_states.transpose(
2, 3)) * self.scale
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights,
dim=-1,
dtype=torch.float32).to(
query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, patches, -1)
return self.o_proj(attn_output)
class PixtralHFTransformerBlock(nn.Module):
def __init__(self, config: PixtralVisionConfig):
super().__init__()
self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
self.attention = PixtralHFAttention(config)
self.feed_forward = PixtralHFMLP(config)
self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
r = self.attention.forward(self.attention_norm(hidden_states),
attention_mask=attention_mask,
position_embeddings=position_embeddings)
h = hidden_states + r
r = self.feed_forward.forward(self.ffn_norm(h))
out = h + r
return out
class PixtralHFTransformer(nn.Module):
def __init__(self, config: PixtralVisionConfig):
super().__init__()
self.layers = torch.nn.ModuleList()
for _ in range(config.num_hidden_layers):
self.layers.append(PixtralHFTransformerBlock(config))
def forward(
self,
x: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
for layer in self.layers:
x = layer(x, attention_mask, position_embeddings)
return x
class PixtralHFVisionModel(nn.Module):
def __init__(self, config: PixtralVisionConfig):
super().__init__()
self.config = config
self.patch_conv = nn.Conv2d(
in_channels=config.num_channels,
out_channels=config.hidden_size,
kernel_size=config.patch_size,
stride=config.patch_size,
bias=False,
)
self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
self.transformer = PixtralHFTransformer(config)
self.dtype = next(self.parameters()).dtype
self.device = next(self.parameters()).device
self.patch_positional_embedding = PixtralRotaryEmbedding(
config, self.device)
def forward(
self,
pixel_values: List[torch.Tensor],
) -> torch.Tensor:
"""
Args:
pixel_values: list of image tensors, each of shape (C, H, W);
height and width may vary per image
Returns:
image_features: token features for all patches of all images,
of shape (1, N_toks, D)
"""
# pass images through initial convolution independently
patch_embeds_list = [
self.patch_conv(
img.reshape(-1, img.shape[-3], img.shape[-2],
img.shape[-1]).to(self.dtype))
for img in pixel_values
]
# flatten to a single sequence
patch_embeds = torch.cat(
[p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
patch_embeds = self.ln_pre(patch_embeds)
# positional embeddings
position_ids = position_ids_in_meshgrid(
patch_embeds_list,
max_width=self.config.image_size // self.config.patch_size).to(
self.device)
position_embedding = self.patch_positional_embedding(
patch_embeds, position_ids)
attention_mask = generate_block_attention_mask(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
patch_embeds)
out = self.transformer(patch_embeds, attention_mask,
position_embedding)
return out
# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = []
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
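For orientation, a minimal sketch of running this vision tower on variable-sized inputs, assuming the library-default PixtralVisionConfig values (patch_size 16, hidden_size 1024) and truncating to two layers so the sketch stays small:

import torch
from transformers import PixtralVisionConfig

cfg = PixtralVisionConfig(num_hidden_layers=2)  # other fields keep library defaults
vision_model = PixtralHFVisionModel(cfg)

# Two images with different spatial sizes, each a (C, H, W) tensor.
images = [torch.randn(3, 64, 48), torch.randn(3, 32, 32)]
features = vision_model(images)
# With 16x16 patches: 4*3 + 2*2 = 16 patches in total -> (1, 16, hidden_size)
print(features.shape)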

View File

@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import lru_cache, partial
from functools import partial
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
Tuple, Type, TypedDict, Union)
@@ -63,7 +63,7 @@ from vllm.multimodal.base import MultiModalData
from vllm.multimodal.image import cached_get_image_processor
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import get_processor
from vllm.transformers_utils.processor import cached_get_processor
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (PPMissingLayer, get_vit_attn_backend,
@@ -544,8 +544,6 @@ class Qwen2VisionTransformer(nn.Module):
# === Vision input helpers === #
cached_get_processor = lru_cache(get_processor)
def mm_input_mapper_for_qwen2_vl(
ctx: InputContext,

View File

@@ -1,3 +1,4 @@
from functools import lru_cache
from typing import Any, cast
@@ -37,6 +38,9 @@ def get_processor(
return cast(ProcessorMixin, processor)
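# Module-level cache: repeated calls with the same processor name and arguments
# reuse a single HF processor instance instead of re-instantiating it.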
cached_get_processor = lru_cache(get_processor)
def get_image_processor(
processor_name: str,
*args: Any,