1123 lines
40 KiB
Python
1123 lines
40 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import math
|
|
from dataclasses import dataclass, fields
|
|
from functools import cached_property
|
|
from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from mistral_common.protocol.instruct.messages import ImageChunk
|
|
from PIL import Image
|
|
from transformers import PixtralVisionConfig
|
|
from transformers.models.pixtral.image_processing_pixtral import (
|
|
_num_image_tokens as _get_pixtral_hf_num_image_tokens)
|
|
from transformers.models.pixtral.modeling_pixtral import (
|
|
PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
|
|
|
|
from vllm.config import VllmConfig
|
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
|
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
|
InputContext, token_inputs)
|
|
from vllm.model_executor.layers.activation import get_act_and_mul_fn
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
RowParallelLinear)
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
|
from vllm.multimodal.inputs import PlaceholderRange
|
|
from vllm.multimodal.utils import consecutive_placeholder_ranges
|
|
from vllm.sequence import IntermediateTensors, SequenceData
|
|
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
|
|
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
|
from .utils import (init_vllm_registered_model, maybe_prefix,
|
|
merge_multimodal_embeddings)
|
|
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
|
|
|
|
try:
|
|
from xformers import ops as xops
|
|
USE_XFORMERS_OPS = True
|
|
except ImportError:
|
|
USE_XFORMERS_OPS = False
|
|
|
|
|
|
def get_max_pixtral_image_tokens(ctx: InputContext):
    """Return the patch count of the largest image the encoder accepts.

    The image settings are read from the tokenizer's multimodal encoder:
    from its ``mm_config`` attribute when present, otherwise from
    ``image_config``. The result is the number of patches per side at
    ``max_image_size``, squared.

    Args:
        ctx: Input context whose ``model_config`` selects the tokenizer.
    """
    mm_encoder = cached_tokenizer_from_config(
        ctx.model_config).instruct.mm_encoder

    # The settings live under one of two attribute names on the encoder.
    if hasattr(mm_encoder, "mm_config"):
        image_config = mm_encoder.mm_config
    else:
        image_config = mm_encoder.image_config

    patches_per_side = (image_config.max_image_size //
                        image_config.image_patch_size)
    return patches_per_side**2
|
|
|
|
|
|
def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
                           mm_counts: Mapping[str, int]):
    """Build dummy sequence and image data for the Pixtral model.

    A black 256x256 image is encoded once to learn how many tokens one
    image occupies; the prompt is then ``num_images`` times that many image
    tokens followed by zeros up to ``seq_len``.

    Args:
        ctx: Input context providing the model and multimodal configs.
        seq_len: Total length of the dummy prompt.
        mm_counts: Not read by this function; the image count comes from
            the multimodal config's per-prompt limit instead.

    Returns:
        DummyData holding the sequence data, the dummy images and their
        consecutive placeholder ranges.
    """
    tokenizer = cached_tokenizer_from_config(ctx.model_config)
    image_token_id = (
        tokenizer.mistral.instruct_tokenizer.mm_encoder.special_ids.img)

    num_images = ctx.get_mm_config().get_limit_per_prompt("image")

    # dummy size
    side = 256
    image = Image.new("RGB", (side, side), color=0)

    # Encode one dummy image to measure its token footprint.
    tokens_per_image = len(
        tokenizer.instruct.mm_encoder(ImageChunk(image=image)).tokens)
    total_image_tokens = tokens_per_image * num_images

    seq_data = SequenceData.from_prompt_token_counts(
        (image_token_id, total_image_tokens),
        (0, seq_len - total_image_tokens),
    )

    mm_data = {"image": num_images * [image]}
    placeholder_ranges = consecutive_placeholder_ranges(
        num_items=num_images, item_size=tokens_per_image)
    return DummyData(seq_data, mm_data, {"image": placeholder_ranges})
|
|
|
|
|
|
def input_mapper_for_pixtral(ctx: InputContext,
                             data: object) -> MultiModalKwargs:
    """Encode image inputs into the keyword arguments the model consumes.

    Args:
        ctx: Context of the loaded model.
        data: A single image or a list of images; each is wrapped in an
            ``ImageChunk`` and passed through the tokenizer's multimodal
            encoder.

    Returns:
        MultiModalKwargs with ``images`` (a list of float16 tensors, one
        per input image) and ``image_tokens`` (one flat tensor with the
        token ids of all images, in input order).
    """
    tokenizer = cached_tokenizer_from_config(ctx.model_config)

    if isinstance(data, list):
        items = data
    else:
        items = [data]

    images = []
    flat_token_ids = []
    for item in items:
        encoding = tokenizer.instruct.mm_encoder(ImageChunk(image=item))
        images.append(
            torch.from_numpy(encoding.image).to(dtype=torch.float16))
        # Token ids of all images are concatenated into one flat list.
        flat_token_ids.extend(encoding.tokens)

    return MultiModalKwargs({
        "images": images,
        "image_tokens": torch.tensor(flat_token_ids),
    })
|
|
|
|
|
|
def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs):
    """Record where image placeholder tokens sit in an image prompt.

    Inputs without image data are returned unchanged. Otherwise the prompt
    token ids are scanned: a run of image / image-break tokens terminated
    by an image-end token becomes one ``PlaceholderRange`` (the end token
    is included in the range length).

    Args:
        ctx: Input context whose ``model_config`` selects the tokenizer.
        inputs: Decoder-only inputs, already tokenized.

    Returns:
        ``inputs`` itself when it carries no image, else new token inputs
        with ``multi_modal_placeholders`` filled in for ``"image"``.

    Raises:
        ValueError: If image data is present but the prompt token ids do
            not contain the image token.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    prompt_token_ids = inputs.get("prompt_token_ids")
    prompt = inputs.get("prompt")
    tokenizer = cached_tokenizer_from_config(ctx.model_config)

    special_ids = tokenizer.mistral.instruct_tokenizer.mm_encoder.special_ids
    image_token_id = special_ids.img
    image_break_id = special_ids.img_break
    image_end_id = special_ids.img_end

    if image_token_id not in inputs['prompt_token_ids']:
        raise ValueError(
            f"You've passed {inputs=} without {image_token_id=}"
            " Make sure to process your input via mistral_common's"
            " tokenizer or pass a chat completion request."
            " For more info, see: "
            "https://github.com/vllm-project/vllm/issues/8411.")

    # Get precise tracking of placeholder positions
    placeholder_ranges = []
    curr_offset = -1  # -1 means "not currently inside an image span"
    curr_length = 0
    for i, token_id in enumerate(prompt_token_ids):
        if token_id in (image_token_id, image_break_id):
            if curr_offset < 0:
                curr_offset = i
            curr_length += 1
        elif token_id == image_end_id:
            # The end token closes the span and counts towards its length.
            # NOTE(review): an end token with no preceding image token
            # yields offset=-1; presumably the tokenizer never emits that.
            curr_length += 1
            placeholder_ranges.append(
                PlaceholderRange(offset=curr_offset, length=curr_length))
            curr_offset = -1
            curr_length = 0

    return token_inputs(prompt=prompt,
                        prompt_token_ids=prompt_token_ids,
                        multi_modal_data=multi_modal_data,
                        multi_modal_placeholders={"image": placeholder_ranges})
|
|
|
|
|
|
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):
    """Pixtral model: a vision encoder and adapter feeding a language model.

    Images go through ``vision_encoder`` and ``vision_language_adapter``;
    the resulting embeddings replace the image placeholder positions of the
    language model's input embeddings.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the language model, vision encoder and adapter.

        Raises:
            ValueError: If the vision config lacks 'image_break_token_id'
                or 'image_end_token_id'.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        # Keep only the vision-config entries that VisionEncoderArgs declares.
        dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
        vision_args = {
            key: value
            for key, value in self.config.vision_config.to_dict().items()
            if key in dataclass_fields
        }

        if not ("image_break_token_id" in vision_args
                and "image_end_token_id" in vision_args):
            raise ValueError(
                "'image_break_token_id' and 'image_end_token_id' not found "
                "in the vision_encoder arguments. Please download the latest "
                "version of 'params.json' from the model repository.")

        self.vision_args = VisionEncoderArgs(**vision_args)

        # init MistralForCausalLM
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.vision_encoder = VisionTransformer(self.vision_args)
        self.vision_language_adapter = VisionLanguageAdapter(
            self.vision_args, dim=config.text_config.hidden_size)

        # Re-export the language model's factory under the same name.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Sampler of the language model if it has one, else a new sampler."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Encode the images in ``kwargs`` into per-image embeddings.

        Returns:
            None when no images are passed. Otherwise the embeddings of
            ``image_tokens`` with the vision encoder output written at the
            image-token positions: a single tensor with a leading dimension
            of 1 when at most one image-end token is present, else the
            result of splitting that tensor after each image-end token.
        """
        image_input, image_tokens = self._parse_and_validate_image_input(
            **kwargs)
        if image_input is None:
            return None

        vision_embeddings = self._process_image_input(image_input)

        # NOTE: We patch the outputs of the vision encoder with embeddings
        # from `[IMG_BREAK]` and `[IMG_END]` tokens.
        image_embeds = self.language_model.get_input_embeddings(image_tokens)
        image_token_mask = image_tokens == self.vision_args.image_token_id
        # In-place overwrite of the image-token rows with vision features.
        image_embeds[image_token_mask] = vision_embeddings

        # NOTE: Image embeddings are split into separate tensors for each image
        # by the indices of `[IMG_END]` token.
        image_end_mask = image_tokens == self.vision_args.image_end_token_id
        # +1 so that each split ends just after its `[IMG_END]` token.
        split_indices = torch.where(image_end_mask)[0] + 1
        if len(split_indices) <= 1:
            # Do not split, return as tensor of shape [1, fs, hs]
            return image_embeds.unsqueeze(0)

        # If the last split index is the last index in image_tokens, we
        # ignore it to avoid empty split tensor
        if split_indices[-1] == len(image_tokens):
            split_indices = split_indices[:-1]

        image_embeds = image_embeds.tensor_split(split_indices.cpu())
        return image_embeds

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in multimodal embeddings, if any.

        The merge targets positions holding the image, image-break or
        image-end token id.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings, [
                    self.vision_args.image_token_id,
                    self.vision_args.image_break_token_id,
                    self.vision_args.image_end_token_id,
                ])
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for pixtral.

        When ``intermediate_tensors`` is given, ``inputs_embeds`` is
        discarded. Otherwise, if ``inputs_embeds`` is missing, it is
        computed here from ``input_ids`` and the images in ``kwargs``, and
        ``input_ids`` is then passed to the language model as None.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def _parse_and_validate_image_input(
        self,
        images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor],
                               torch.Tensor]] = None,
        image_tokens: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]:
        """Flatten batched images and image tokens.

        Returns:
            ``(None, None)`` when ``images`` is None; otherwise a flat list
            of image tensors and a 1-D tensor of image token ids.
        """
        if images is None:
            return None, None

        if isinstance(images, torch.Tensor):
            # if passed as batch take all images
            N, B, C, W, H = images.shape
            images = images.reshape(N * B, C, W, H)
            images = [images[i] for i in range(images.size(0))]
        elif isinstance(images, list):
            # if passed as list flatten lists of tensors
            flatten_images = []
            for imgs_per_req in images:
                # A tensor entry is unbound along dim 0; a list is used as is.
                imgs_per_req = [
                    imgs_per_req[i] for i in range(imgs_per_req.size(0))
                ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req

                flatten_images.extend(imgs_per_req)

            images = flatten_images

        if isinstance(image_tokens, torch.Tensor):
            # image_tokens are batched
            image_tokens = image_tokens.flatten()
        elif isinstance(image_tokens, list):
            # image_tokens are of different lengths thus passed as a list
            image_tokens = torch.cat(image_tokens)

        # NOTE(review): this fails with AttributeError when images are given
        # but image_tokens is None.
        assert image_tokens.dim() == 1

        return images, image_tokens

    def _process_image_input(self,
                             image_input: List[torch.Tensor]) -> torch.Tensor:
        """Run the images through the vision encoder, then the adapter."""
        return self.vision_language_adapter(self.vision_encoder(image_input))

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load vision, adapter and language-model weights in one pass.

        Weights named ``vision_encoder...`` or ``vision_language_adapter...``
        are copied straight into the matching parameter (looked up by the
        name with its first dotted component removed). Every other weight
        is yielded to ``language_model.load_weights``.
        """

        def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_encoder")

        def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_language_adapter")

        # Get references to parameters for direct loading
        vision_encoder_dict = dict(self.vision_encoder.named_parameters())
        vision_lang_adapter_dict = dict(
            self.vision_language_adapter.named_parameters())

        def llm_weights_generator():
            # Single pass over weights
            # The vision/adapter branches run as a side effect while the
            # language model consumes this generator.
            for name, w in weights:
                if is_vision_encoder_weights((name, w)):
                    # Load vision encoder weights directly
                    trimmed_name = '.'.join(name.split(".")[1:])
                    param = vision_encoder_dict[trimmed_name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                elif is_vision_lang_adapter_weights((name, w)):
                    # Load vision-language adapter weights directly
                    trimmed_name = '.'.join(name.split(".")[1:])
                    param = vision_lang_adapter_dict[trimmed_name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                else:
                    # LLM weights: yield them to be loaded
                    # by language_model.load_weights
                    yield (name, w)

        # Now we call the language model load with the generator
        self.language_model.load_weights(llm_weights_generator())
|
|
|
|
|
|
# Vision encoder
|
|
@dataclass
class VisionEncoderArgs:
    """Settings of the vision encoder and vision-language adapter."""
    # Width of the encoder: patch-conv output channels and linear layer size.
    hidden_size: int
    # Input channels of the patch convolution.
    num_channels: int
    # image_size // patch_size gives the maximum patches per side.
    image_size: int
    # Kernel size and stride of the patch convolution.
    patch_size: int
    # Inner width of the feed-forward layers.
    intermediate_size: int
    # Number of transformer blocks.
    num_hidden_layers: int
    # Number of attention heads; must divide hidden_size.
    num_attention_heads: int
    rope_theta: float  # for rope-2D
    # Token ids used to locate image positions and split per-image embeddings.
    image_token_id: int
    image_break_token_id: int
    image_end_token_id: int
    # Whether the adapter's linear layers carry a bias.
    adapter_bias: bool = True
|
|
|
|
|
|
def _reshape_for_broadcast(freqs_cis: torch.Tensor,
                           x: torch.Tensor) -> torch.Tensor:
    """View ``freqs_cis`` so it broadcasts against ``x``.

    freqs_cis: complex - (seq_len, head_dim / 2)
    x: complex - (bsz, seq_len, head_dim / 2)

    The returned view keeps ``x``'s sizes on dimension 1 and on the last
    dimension and has size 1 everywhere else.
    """
    rank = x.ndim
    assert rank > 1
    expected = (x.shape[1], x.shape[-1])
    assert freqs_cis.shape == expected, (
        freqs_cis.shape,
        expected,
    )
    # Singleton everywhere except dim 1 and the trailing dim.
    target_shape = [1] * rank
    target_shape[1] = x.shape[1]
    target_shape[rank - 1] = x.shape[-1]
    return freqs_cis.view(*target_shape)
|
|
|
|
|
|
def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """Precompute 2D rotary factors for every (row, column) position.

    freqs_cis: 2D complex tensor of shape (height, width, dim // 2)
    to be indexed by (height, width) position tuples

    Even-indexed frequency bases are driven by the row index and
    odd-indexed ones by the column index; the two halves are concatenated
    on the last dimension.
    """
    # (dim / 2) frequency bases
    exponents = torch.arange(0, dim, 2).float() / dim
    freqs = 1.0 / (theta**exponents)

    rows = torch.arange(height, device=freqs.device)
    cols = torch.arange(width, device=freqs.device)

    row_angles = torch.outer(rows, freqs[::2]).float()
    col_angles = torch.outer(cols, freqs[1::2]).float()

    # Broadcast each half over the other axis before joining them.
    angles = torch.cat(
        [
            row_angles.unsqueeze(1).expand(-1, width, -1),
            col_angles.unsqueeze(0).expand(height, -1, -1),
        ],
        dim=-1,
    )
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)
|
|
|
|
|
|
def apply_rotary_emb_vit(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by the complex factors in ``freqs_cis``.

    The last dimension of ``xq`` / ``xk`` is read as (real, imag) pairs,
    multiplied by ``freqs_cis`` (which must be complex64) and converted
    back to the input dtypes.
    """

    def _as_complex(t: torch.Tensor) -> torch.Tensor:
        # Pair up the trailing dimension as (real, imag).
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)
    assert freqs_cis.dtype == torch.complex64
    rotation = _reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Bias-free gated feed-forward layer: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        assert args.intermediate_size is not None
        model_dim = args.hidden_size
        inner_dim = args.intermediate_size
        self.w1 = nn.Linear(model_dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, model_dim, bias=False)
        self.w3 = nn.Linear(model_dim, inner_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
|
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings, run via xformers."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        assert args.hidden_size % args.num_attention_heads == 0
        self.n_heads = args.num_attention_heads
        self.head_dim = args.hidden_size // args.num_attention_heads

        def _projection() -> nn.Linear:
            # All four projections are square and bias-free.
            return nn.Linear(args.hidden_size, args.hidden_size, bias=False)

        self.wq = _projection()
        self.wk = _projection()
        self.wv = _projection()
        self.wo = _projection()

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """Attend over ``x`` of shape (batch, patches, hidden).

        ``mask`` is handed to xformers as ``attn_bias``; ``freqs_cis``
        holds the rotary factors applied to queries and keys.
        """
        batch, patches, _ = x.shape
        per_head_shape = (batch, patches, self.n_heads, self.head_dim)

        q = self.wq(x).reshape(per_head_shape)
        k = self.wk(x).reshape(per_head_shape)
        v = self.wv(x).reshape(per_head_shape)

        q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
        attended = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
        merged = attended.reshape(batch, patches,
                                  self.n_heads * self.head_dim)
        return self.wo(merged)
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm block: attention and feed-forward, each with a residual."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        norm_eps = 1e-5
        self.attention_norm = RMSNorm(args.hidden_size, eps=norm_eps)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        attn_out = self.attention.forward(self.attention_norm(x),
                                          mask=mask,
                                          freqs_cis=freqs_cis)
        after_attention = x + attn_out
        ffn_out = self.feed_forward.forward(self.ffn_norm(after_attention))
        return after_attention + ffn_out
|
|
|
|
|
|
class Transformer(nn.Module):
    """Stack of ``num_hidden_layers`` TransformerBlock modules."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [TransformerBlock(args) for _ in range(args.num_hidden_layers)])

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Apply every block in order, passing the same mask and freqs."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask=mask, freqs_cis=freqs_cis)
        return hidden
|
|
|
|
|
|
def position_meshgrid(patch_embeds_list: List[torch.Tensor], ) -> torch.Tensor:
    """Return (row, col) indices for every patch of every image.

    For each tensor, its last two dimensions are taken as the grid height
    and width; the grid positions are listed row by row and the per-image
    lists are concatenated into one tensor with two columns.
    """
    per_image_positions = []
    for patch_embeds in patch_embeds_list:
        row_idx = torch.arange(patch_embeds.shape[-2])
        col_idx = torch.arange(patch_embeds.shape[-1])
        grid = torch.meshgrid(row_idx, col_idx, indexing="ij")
        per_image_positions.append(torch.stack(grid, dim=-1).reshape(-1, 2))
    return torch.cat(per_image_positions)
|
|
|
|
|
|
class VisionTransformer(nn.Module):
    """Vision encoder: patch convolution, RMSNorm, then a Transformer.

    All images of a call are packed into one sequence; a block-diagonal
    mask keeps attention within each image.
    """

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        patch = args.patch_size
        self.patch_conv = nn.Conv2d(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=patch,
            stride=patch,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = Transformer(args)

        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        # Rotary table, built lazily on first access of `freqs_cis`.
        self._freqs_cis: Optional[torch.Tensor] = None

    @property
    def max_patches_per_side(self) -> int:
        return self.args.image_size // self.args.patch_size

    @property
    def device(self) -> torch.types.Device:
        first_param = next(self.parameters())
        return first_param.device

    @property
    def dtype(self) -> torch.dtype:
        first_param = next(self.parameters())
        return first_param.dtype

    @property
    def freqs_cis(self) -> torch.Tensor:
        """Rotary table for the full patch grid, on the module's device."""
        table = self._freqs_cis
        if table is None:
            side = self.max_patches_per_side
            table = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=side,
                width=side,
                theta=self.args.rope_theta,
            )
            self._freqs_cis = table

        # Follow the parameters if the module has moved devices.
        module_device = self.device
        if table.device != module_device:
            table = table.to(device=module_device)
            self._freqs_cis = table

        return table

    def forward(
        self,
        images: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            images: list of N_img images of variable sizes,
                each of shape (C, H, W)
        Returns:
            image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
        Raises:
            ImportError: if xformers could not be imported.
        """
        # pass images through initial convolution independently
        patch_embeds_list = []
        for img in images:
            batched = img.unsqueeze(0).to(self.dtype)
            patch_embeds_list.append(self.patch_conv(batched))

        # flatten to a single sequence
        flattened = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list]
        patch_embeds = self.ln_pre(torch.cat(flattened, dim=1))

        # positional embeddings
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

        # pass through Transformer with a block diagonal mask delimiting images
        if not USE_XFORMERS_OPS:
            raise ImportError("Xformers is required for Pixtral inference "
                              "with the Mistral format")
        tokens_per_image = [
            p.shape[-2] * p.shape[-1] for p in patch_embeds_list
        ]
        mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
            tokens_per_image, )
        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

        # remove batch dimension of the single sequence
        return out.squeeze(0)
|
|
|
|
|
|
class VisionLanguageAdapter(nn.Module):
    """Two linear layers with GELU between, mapping hidden_size to ``dim``."""

    def __init__(self, args: VisionEncoderArgs, dim: int):
        super().__init__()
        assert isinstance(args, VisionEncoderArgs)
        use_bias = args.adapter_bias
        self.w_in = nn.Linear(args.hidden_size, dim, bias=use_bias)
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=use_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.w_in(x)
        activated = self.gelu(projected)
        return self.w_out(activated)
|
|
|
|
|
|
#### HF Transformers version of Pixtral ####
|
|
# Based off https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py
|
|
# This model follows the Llava family, meaning image embeddings are placed
|
|
# instead of the `[IMG]` token placeholders.
|
|
# The model uses [`PixtralVisionModel`] for its vision encoder,
|
|
# and [`MistralForCausalLM`] for its language decoder.
|
|
|
|
|
|
def get_pixtral_hf_patch_grid_length(*, image_size: int,
                                     patch_size: int) -> int:
    """Return how many whole patches fit along one side of the image."""
    # Since interpolation is applied, the image size need not be divisible
    # by the patch size; any remainder is dropped.
    whole_patches, _remainder = divmod(image_size, patch_size)
    return whole_patches
|
|
|
|
|
|
def get_pixtral_hf_image_feature_size(
    *,
    image_size: int,
    patch_size: int,
) -> int:
    """Return the token count of a square image of side ``image_size``.

    With ``side`` patches per side, the count is ``side * (side + 1)``:
    the patch tokens plus ``side`` extra tokens for the image_break_token.
    """
    side = get_pixtral_hf_patch_grid_length(
        image_size=image_size,
        patch_size=patch_size,
    )

    # Consider the image_break_token
    patch_tokens_and_breaks = side + 1
    return patch_tokens_and_breaks * side
|
|
|
|
|
|
def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int:
    """Return the token count of an image of the configured ``image_size``.

    The count is ``side * (side + 1)`` where ``side`` is the patch grid
    length for ``hf_config.image_size`` and ``hf_config.patch_size``.
    """
    side = get_pixtral_hf_patch_grid_length(
        image_size=hf_config.image_size,
        patch_size=hf_config.patch_size,
    )

    # Consider the image_break_token
    tokens_per_row = side + 1
    return tokens_per_row * side
|
|
|
|
|
|
def dummy_image_for_pixtral_hf(
    hf_config: PixtralVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create black dummy image data for the HF Pixtral vision model.

    The image is ``hf_config.image_size`` square unless a width or height
    override is given.

    Returns:
        ``{"image": image}`` when ``num_images`` is 1, otherwise
        ``{"image": [image] * num_images}`` (the same image object repeated).
    """
    default_side = hf_config.image_size
    width = (default_side
             if image_width_override is None else image_width_override)
    height = (default_side
              if image_height_override is None else image_height_override)

    image = Image.new("RGB", (width, height), color=0)
    if num_images == 1:
        return {"image": image}
    return {"image": [image] * num_images}
|
|
|
|
|
|
# Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501
|
|
# https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180
|
|
def get_pixtral_hf_image_feature_grid_size(
    hf_config: PixtralVisionConfig,
    *,
    image_width: int,
    image_height: int,
) -> tuple[int, int]:
    """Return the patch grid size ``(ncols, nrows)`` for an image.

    An image larger than ``hf_config.image_size`` on either side is first
    scaled down (keeping aspect ratio, rounding up) so its longer relative
    side fits; the patch counts are then computed by the HF helper.
    """
    side_limit = hf_config.image_size
    patch_side = hf_config.patch_size

    # Largest overshoot factor across the two dimensions.
    scale = max(image_width / side_limit, image_height / side_limit)

    if scale > 1:
        image_width = int(math.ceil(image_width / scale))
        image_height = int(math.ceil(image_height / scale))

    image_hw = (image_height, image_width)
    patch_hw = (patch_side, patch_side)
    nrows, ncols = _get_pixtral_hf_num_image_tokens(
        image_hw,
        patch_hw,
    )  # type: ignore

    return ncols, nrows
|
|
|
|
|
|
class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
    """Size and token-count queries backed by a PixtralVisionConfig."""

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        # NOTE(review): image_width / image_height are not used; the count
        # is derived from the configured image_size only.
        cfg = self.vision_config
        return get_pixtral_hf_image_feature_size(
            image_size=cfg.image_size,
            patch_size=cfg.patch_size,
        )

    def get_max_image_tokens(self) -> int:
        cfg = self.vision_config
        return get_max_pixtral_hf_image_tokens(cfg)

    def get_image_size(self) -> int:
        cfg = self.vision_config
        return cfg.image_size

    def get_patch_size(self) -> int:
        cfg = self.vision_config
        return cfg.patch_size

    def get_patch_grid_length(self) -> int:
        cfg = self.vision_config
        return get_pixtral_hf_patch_grid_length(
            image_size=cfg.image_size,
            patch_size=cfg.patch_size,
        )
|
|
|
|
|
|
class PixtralHFMLP(nn.Module):
    """Gated MLP: merged gate/up projection, activation, down projection."""

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        assert config.intermediate_size is not None
        hidden_dim = config.hidden_size
        inner_dim = config.intermediate_size

        # Gate and up projections share one merged column-parallel layer.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_dim,
            output_sizes=[inner_dim, inner_dim],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(input_size=inner_dim,
                                           output_size=hidden_dim,
                                           bias=False,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.down_proj")
        self.act_and_mul = get_act_and_mul_fn(config.hidden_act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_and_mul(gate_up)
        projected, _ = self.down_proj(activated)
        return projected
|
|
|
|
|
|
class PixtralHFAttention(nn.Module):
    """Self-attention for the HF-format Pixtral vision encoder.

    Uses a fused QKV projection split across tensor-parallel ranks, HF's
    rotary position embedding, and either xformers memory-efficient
    attention or torch's scaled_dot_product_attention as a fallback.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        assert not config.hidden_size % config.num_attention_heads
        self.total_num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        # Heads handled by this tensor-parallel rank.
        self.n_heads = divide(config.num_attention_heads, tp_size)
        self.head_dim = config.hidden_size // config.num_attention_heads

        self.qkv_proj = QKVParallelLinear(
            hidden_size=config.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        assert self.total_num_heads * self.head_dim == config.hidden_size
        self.o_proj = RowParallelLinear(
            input_size=config.hidden_size,
            output_size=config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Attend over ``hidden_states`` of shape (batch, patches, hidden).

        Args:
            hidden_states: Input features.
            attention_mask: Passed as ``attn_bias`` to xformers or as
                ``attn_mask`` to scaled_dot_product_attention.
            position_embeddings: ``(cos, sin)`` pair for the rotary
                embedding.

        Returns:
            ``(attn_output, None)``; the second element is always None.
        """
        batch, patches, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv_states.chunk(3, dim=-1)

        # Transpose q and k to apply HF's Rotary Position Embedding
        q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch, patches, self.n_heads, self.head_dim)
        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)

        if USE_XFORMERS_OPS:
            # Transpose q and k back for attention
            q = q.transpose(1, 2).contiguous()
            k = k.transpose(1, 2).contiguous()

            out = xops.memory_efficient_attention(q,
                                                  k,
                                                  v,
                                                  attn_bias=attention_mask)
        else:
            v = v.transpose(1, 2)
            out = nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask)
            out = out.transpose(1, 2)

        # reshape rather than view: after the transpose in the fallback
        # branch the tensor may be non-contiguous, on which view raises.
        # reshape returns a view whenever one is possible.
        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
        attn_output, _ = self.o_proj(out)

        return attn_output, None
|
|
|
|
|
|
class PixtralHFTransformerBlock(nn.Module):
    """Pre-norm block: attention and MLP, each followed by a residual add."""

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        norm_eps = 1e-5
        self.attention_norm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.attention = PixtralHFAttention(config,
                                            quant_config=quant_config,
                                            prefix=f"{prefix}.attention")
        self.feed_forward = PixtralHFMLP(config,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.feed_forward")
        self.ffn_norm = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        normed_input = self.attention_norm(hidden_states)
        # The attention module returns a pair; only the output is used.
        attn_out, _ = self.attention.forward(
            normed_input,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings)
        after_attention = hidden_states + attn_out
        mlp_out = self.feed_forward.forward(self.ffn_norm(after_attention))
        return after_attention + mlp_out
|
|
|
|
|
|
class PixtralHFTransformer(nn.Module):
    """Stack of PixtralHFTransformerBlock layers.

    ``num_hidden_layers_override`` replaces ``config.num_hidden_layers``
    as the number of layers built when it is not None.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        num_hidden_layers = (config.num_hidden_layers
                             if num_hidden_layers_override is None else
                             num_hidden_layers_override)

        self.layers = nn.ModuleList()
        for layer_idx in range(num_hidden_layers):
            self.layers.append(
                PixtralHFTransformerBlock(
                    config=config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}"))

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
        return_all_hidden_states: bool,
    ) -> torch.Tensor:
        """Run all layers.

        Returns the last layer's output, or — when
        ``return_all_hidden_states`` is set — a list holding the input
        followed by every layer's output, in order.
        """
        collected = [x]
        current = x

        for block in self.layers:
            current = block(current, attention_mask, position_embeddings)
            if return_all_hidden_states:
                collected.append(current)

        # If we have multiple feature sample layers, we return all hidden
        # states in order and grab the ones we need by index.
        return collected if return_all_hidden_states else current
|
|
|
|
|
|
class PixtralHFVisionModel(nn.Module):
    """Vision encoder: patch convolution, pre-norm and a transformer stack.

    Every image is turned into patch embeddings by a strided convolution. The
    embeddings of all images are concatenated into one sequence, normalized,
    and run through ``PixtralHFTransformer`` together with rotary position
    embeddings and a block attention mask built from the per-image token
    counts.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            config: Vision config. ``num_channels``, ``hidden_size``,
                ``patch_size`` and ``num_hidden_layers`` are read here.
            quant_config: Passed on to the transformer.
            num_hidden_layers_override: Passed on to the transformer to build
                a different number of layers than the config specifies.
            require_post_norm: Must not be ``True``; this model has no
                post-layernorm.
            prefix: Module-name prefix; the transformer receives
                ``f"{prefix}.transformer"``.

        Raises:
            ValueError: If more layers were built than
                ``config.num_hidden_layers``, or if ``require_post_norm`` is
                ``True``.
        """
        super().__init__()

        self.config = config

        # kernel_size == stride, so the convolution windows do not overlap.
        self.patch_conv = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
        self.transformer = PixtralHFTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.transformer",
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.transformer.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.transformer.layers)} "
                "layers.")

        if require_post_norm is True:
            msg = "PixtralHFVisionModel does not have post-layernorm"
            raise ValueError(msg)

        # Snapshot of the first parameter's dtype/device at construction
        # time. forward() uses these to cast its inputs; they are not
        # refreshed if the module is moved or cast afterwards.
        self.dtype = next(self.parameters()).dtype
        self.device = next(self.parameters()).device
        self.patch_positional_embedding = PixtralRotaryEmbedding(
            config, self.device)

    def forward(
        self,
        pixel_values: List[torch.Tensor],
        feature_sample_layers: Optional[list[int]] = None,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Args:
            pixel_values: Each image to be processed will be a separate tensor
                in pixel_values. This means it will be a list of tensors
                because multiple requests batched can have multiple images,
                each with their own shape potentially
            feature_sample_layers: Layer indices whose features should be
                concatenated and used as the visual encoder output. If none
                are provided, the last layer is used.

        Returns:
            image_features: tuple with one tensor of token features per
                input image, in the same order as ``pixel_values``; tensor
                ``i`` holds as many rows as image ``i`` has patches.
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype))
            for img in pixel_values
        ]

        # collapse the two spatial dims into one token dim and move it in
        # front of the channel dim
        patch_embeds = [
            p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
        ]
        # number of patch tokens contributed by each image
        embed_sizes = [p.shape[1] for p in patch_embeds]

        # flatten to a single sequence
        patch_embeds = torch.cat(patch_embeds, dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        position_ids = position_ids_in_meshgrid(
            patch_embeds_list,
            max_width=self.config.image_size // self.config.patch_size).to(
                self.device)
        position_embedding = self.patch_positional_embedding(
            patch_embeds, position_ids)

        # The mask is built from the per-image token counts. embed_sizes
        # already holds height * width of every patch grid, so it is reused
        # rather than recomputed from patch_embeds_list.
        if USE_XFORMERS_OPS:
            attention_mask = (
                xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                    embed_sizes))
        else:
            from transformers.models.pixtral.modeling_pixtral import (
                generate_block_attention_mask)
            attention_mask = generate_block_attention_mask(
                embed_sizes, patch_embeds)

        return_all_hidden_states = feature_sample_layers is not None
        out = self.transformer(
            patch_embeds,
            attention_mask,
            position_embedding,
            return_all_hidden_states=return_all_hidden_states)

        out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
                                             self.config.num_hidden_layers)

        # squeeze dim 0 and split into separate tensors for each image.
        # Only dim 0 is squeezed: squeezing without a dim would also remove
        # the token dim when the whole batch consists of a single patch, and
        # the split below would then run along the feature dim.
        return torch.split(out.squeeze(0), embed_sizes)

    # (TODO) Add prefix argument for filtering out weights to be loaded
    #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate q/k/v and gate/up projection weights are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters. Weights belonging
        to transformer layers that were not built are skipped.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The names of the parameters that were loaded; for fused
            parameters this is the fused name.

        Raises:
            KeyError: If a (possibly remapped) name is not one of this
                module's parameters.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        layer_count = len(self.transformer.layers)

        for name, loaded_weight in weights:
            # omit layers when num_hidden_layers_override is set
            if name.startswith("transformer.layers"):
                # name looks like "transformer.layers.<idx>...."
                layer_idx = int(name.split(".")[2])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No fused mapping matched: load the tensor as-is.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|