# SPDX-License-Identifier: Apache-2.0

import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from functools import cached_property, partial
from typing import List, Optional, Set, Tuple, TypedDict, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin,
                          TensorType)
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput

from vllm.attention import Attention
from vllm.attention.layer import MultiHeadAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather)
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
                                                   SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                   MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptIndexTargets,
                                        PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                         SupportsMultiModal, SupportsPP, SupportsQuant)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features

# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9]
NUM_PREFIX_TOKENS = 1
ADDITIONAL_VOCAB_SIZE = 128

IMAGE_PATCH_TOKEN = "<im_patch>"
IM_COL_TOKEN = "<im_col>"
IM_START_TOKEN = "<im_start>"
IM_END_TOKEN = "<im_end>"
POOLING_SIZE = 2
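# A note on the constants above (illustrative, using the default
# VisionBackboneConfig defined below):
#   * VIT_LAYERS = [-2, -9]: encode_image() concatenates the hidden states of
#     these two ViT blocks, so each patch feature ends up with
#     len(VIT_LAYERS) * image_emb_dim = 2 * 1024 = 2048 channels.
#   * POOLING_SIZE = 2: MolmoVisionBackbone.forward() attention-pools every
#     2x2 window of patches, so a 24x24 crop is reduced to
#     ((24 + 1) // 2) ** 2 = 144 pooled image tokens.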
|
|
|
|
|
|
class MolmoImageInputs(TypedDict):
|
|
images: Union[torch.Tensor, list[torch.Tensor]]
|
|
"""Shape: `(batch_size * num_images, num_crops, num_patch, patch_dim)`"""
|
|
|
|
image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]]
|
|
"""Shape: `(batch_size * num_images, num_crops, num_patch)`"""
|
|
|
|
feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
|
"""
|
|
A boolean mask indicating which image features correspond
|
|
to patch tokens.
|
|
|
|
Shape: `(batch_size * num_images, num_crops, num_patch)`
|
|
"""
|
|
|
|
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
|
"""
|
|
A boolean mask indicating which image embeddings correspond
|
|
to patch tokens.
|
|
|
|
Shape: `(batch_size * num_images, num_embeds)`
|
|
"""
|
|
|
|
num_crops: torch.Tensor
|
|
"""Shape: `(batch_size * num_images)`"""
|
|
|
|
|
|
@dataclass
|
|
class VisionBackboneConfig:
|
|
image_default_input_size: Tuple[int, int] = (336, 336)
|
|
image_patch_size: int = 14
|
|
image_pos_patch_size: int = 14
|
|
image_emb_dim: int = 1024
|
|
image_num_heads: int = 16
|
|
image_num_key_value_heads: int = 16
|
|
image_num_layers: int = 23
|
|
image_mlp_dim: int = 4096
|
|
image_mlp_activations: str = "quick_gelu"
|
|
image_num_pos: int = 577
|
|
image_norm_eps: float = 1e-5
|
|
|
|
def __post_init__(self):
|
|
self.image_default_input_size = tuple(
|
|
self.image_default_input_size) # type: ignore[assignment]
|
|
|
|
@property
|
|
def image_num_patch(self):
|
|
h, w = self.image_default_input_size
|
|
return h // self.image_patch_size, w // self.image_patch_size
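    # Illustrative sketch (not used at runtime): with the defaults above, the
    # ViT sees a 24x24 grid of 14x14 patches plus one CLS token, which is why
    # image_num_pos = 24 * 24 + 1 = 577.
    #
    #     cfg = VisionBackboneConfig()
    #     assert cfg.image_num_patch == (24, 24)  # 336 // 14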
|
|
|
|
|
|
class ViTMLP(nn.Module):
|
|
"""MLP used in Vision Transformer."""
|
|
|
|
def __init__(
|
|
self,
|
|
config: VisionBackboneConfig,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
):
|
|
super().__init__()
|
|
self.w1 = ColumnParallelLinear(
|
|
config.image_emb_dim,
|
|
config.image_mlp_dim,
|
|
bias=True,
|
|
quant_config=quant_config,
|
|
)
|
|
# Activation function.
|
|
assert config.image_mlp_activations == "quick_gelu"
|
|
self.act = QuickGELU()
|
|
self.w2 = RowParallelLinear(
|
|
config.image_mlp_dim,
|
|
config.image_emb_dim,
|
|
bias=True,
|
|
quant_config=quant_config,
|
|
)
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
x, _ = self.w1(x)
|
|
x = self.act(x)
|
|
x, _ = self.w2(x)
|
|
return x
|
|
|
|
|
|
class MultiHeadDotProductAttention(nn.Module):
|
|
"""Multi-head attention used in Vision Transformer."""
|
|
|
|
def __init__(
|
|
self,
|
|
config: VisionBackboneConfig,
|
|
use_bias: bool = True,
|
|
nlayers: int = 1,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
):
|
|
super().__init__()
|
|
|
|
self.hidden_size = config.image_emb_dim
|
|
self.total_num_heads = config.image_num_heads
|
|
tp_size = get_tensor_model_parallel_world_size()
|
|
|
|
assert self.hidden_size % self.total_num_heads == 0
|
|
assert self.total_num_heads % tp_size == 0
|
|
|
|
self.num_heads = self.total_num_heads // tp_size
|
|
self.head_dim = self.hidden_size // self.total_num_heads
|
|
|
|
self.total_num_kv_heads = config.image_num_key_value_heads
|
|
if self.total_num_kv_heads >= tp_size:
|
|
assert self.total_num_kv_heads % tp_size == 0
|
|
else:
|
|
assert tp_size % self.total_num_kv_heads == 0
|
|
|
|
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
|
|
|
self.wq = ColumnParallelLinear(
|
|
nlayers * self.hidden_size,
|
|
self.total_num_heads * self.head_dim,
|
|
bias=use_bias,
|
|
quant_config=quant_config,
|
|
)
|
|
self.wk = ColumnParallelLinear(
|
|
nlayers * self.hidden_size,
|
|
self.total_num_kv_heads * self.head_dim,
|
|
bias=use_bias,
|
|
quant_config=quant_config,
|
|
)
|
|
self.wv = ColumnParallelLinear(
|
|
nlayers * self.hidden_size,
|
|
self.total_num_kv_heads * self.head_dim,
|
|
bias=use_bias,
|
|
quant_config=quant_config,
|
|
)
|
|
self.wo = RowParallelLinear(
|
|
self.total_num_heads * self.head_dim,
|
|
self.hidden_size,
|
|
bias=use_bias,
|
|
quant_config=quant_config,
|
|
)
|
|
|
|
self.scale = self.head_dim**-0.5
|
|
self.attn = MultiHeadAttention(self.num_heads,
|
|
self.head_dim,
|
|
self.scale,
|
|
num_kv_heads=self.num_kv_heads)
|
|
|
|
def forward(self,
|
|
inputs_q: torch.Tensor,
|
|
inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
|
|
if inputs_kv is not None:
|
|
inputs_k = inputs_kv
|
|
inputs_v = inputs_kv
|
|
else:
|
|
inputs_k = inputs_q
|
|
inputs_v = inputs_q
|
|
|
|
xq, _ = self.wq(inputs_q)
|
|
xk, _ = self.wk(inputs_k)
|
|
xv, _ = self.wv(inputs_v)
|
|
|
|
output = self.attn(xq, xk, xv)
|
|
output, _ = self.wo(output)
|
|
|
|
return output
|
|
|
|
|
|
class ResidualAttentionBlock(nn.Module):
|
|
"""Residual attention block used in Vision Transformer."""
|
|
|
|
def __init__(
|
|
self,
|
|
config: VisionBackboneConfig,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
):
|
|
super().__init__()
|
|
self.attention = MultiHeadDotProductAttention(
|
|
config, quant_config=quant_config)
|
|
self.feed_forward = ViTMLP(config, quant_config)
|
|
self.attention_norm = nn.LayerNorm(
|
|
config.image_emb_dim,
|
|
eps=config.image_norm_eps,
|
|
)
|
|
self.ffn_norm = nn.LayerNorm(
|
|
config.image_emb_dim,
|
|
eps=config.image_norm_eps,
|
|
)
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
x = x + self.attention(self.attention_norm(x))
|
|
x = x + self.feed_forward(self.ffn_norm(x))
|
|
return x
|
|
|
|
|
|
class BlockCollection(nn.Module):
|
|
"""Collection of residual attention blocks used in Vision Transformer."""
|
|
|
|
def __init__(
|
|
self,
|
|
config: VisionBackboneConfig,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
):
|
|
super().__init__()
|
|
self.resblocks = nn.ModuleList([
|
|
ResidualAttentionBlock(config, quant_config)
|
|
for _ in range(config.image_num_layers)
|
|
])
|
|
|
|
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
|
hidden_states = []
|
|
for r in self.resblocks:
|
|
x = r(x)
|
|
hidden_states.append(x)
|
|
return hidden_states
|
|
|
|
|
|
def _expand_token(token: torch.Tensor, batch_size: int) -> torch.Tensor:
|
|
return token.view(1, 1, -1).expand(batch_size, -1, -1)
|
|
|
|
|
|
class VisionTransformer(nn.Module):
|
|
"""Vision Transformer used in Vision Backbone."""
|
|
|
|
def __init__(
|
|
self,
|
|
config: VisionBackboneConfig,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
):
|
|
super().__init__()
|
|
scale = config.image_emb_dim**-0.5
|
|
self.patch_num = config.image_num_patch
|
|
self.class_embedding = nn.Parameter(
|
|
torch.randn(config.image_emb_dim) * scale)
|
|
self.num_prefix_tokens: int = NUM_PREFIX_TOKENS
|
|
self.positional_embedding = nn.Parameter(
|
|
torch.randn(config.image_num_pos, config.image_emb_dim) * scale)
|
|
image_patch_size = config.image_patch_size
|
|
self.patch_embedding = nn.Linear(
|
|
image_patch_size * image_patch_size * 3,
|
|
config.image_emb_dim,
|
|
bias=False,
|
|
)
|
|
self.pre_ln = nn.LayerNorm(config.image_emb_dim,
|
|
eps=config.image_norm_eps)
|
|
self.transformer = BlockCollection(config, quant_config)
|
|
|
|
    def add_pos_emb(self, x: torch.Tensor,
                    patch_num: Tuple[int, int]) -> torch.Tensor:
|
|
cls_emb = self.positional_embedding[0:1]
|
|
pos_emb = self.positional_embedding[1:]
|
|
|
|
pos_emb = pos_emb.reshape(
|
|
(int(math.sqrt(pos_emb.shape[0])),
|
|
int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1]))
|
|
|
|
(patch_num_0, patch_num_1) = patch_num
|
|
|
|
if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
|
|
# from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
|
pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
|
|
pos_emb = F.interpolate(
|
|
pos_emb,
|
|
size=(patch_num_0, patch_num_1),
|
|
mode="bicubic",
|
|
align_corners=False,
|
|
antialias=True,
|
|
)
|
|
pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)
|
|
|
|
pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
|
|
x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]],
|
|
dim=1).to(x.dtype)
|
|
return x
|
|
|
|
def forward(self,
|
|
x: torch.Tensor,
|
|
                patch_num: Optional[Tuple[int, int]] = None) -> List[torch.Tensor]:
|
|
"""
|
|
: param x: (batch_size, num_patch, n_pixels)
|
|
"""
|
|
if patch_num is None:
|
|
patch_num = self.patch_num
|
|
B, N, D = x.shape
|
|
|
|
x = self.patch_embedding(x)
|
|
|
|
# class embeddings and positional embeddings
|
|
x = torch.cat(
|
|
[_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x],
|
|
dim=1)
|
|
x = self.add_pos_emb(x, patch_num)
|
|
|
|
x = self.pre_ln(x)
|
|
|
|
hidden_states = self.transformer(x)
|
|
return hidden_states
|
|
|
|
|
|
class MolmoAttention(nn.Module):
|
|
"""Molmo's LLM attention."""
|
|
|
|
def __init__(
|
|
self,
|
|
config: PretrainedConfig,
|
|
cache_config: Optional[CacheConfig] = None,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
prefix: str = "",
|
|
) -> None:
|
|
super().__init__()
|
|
self.hidden_size = config.hidden_size
|
|
self.tp_size = get_tensor_model_parallel_world_size()
|
|
self.total_num_heads = config.num_attention_heads
|
|
|
|
assert self.hidden_size % self.total_num_heads == 0
|
|
assert self.total_num_heads % self.tp_size == 0
|
|
|
|
self.num_heads = self.total_num_heads // self.tp_size
|
|
self.total_num_kv_heads = config.num_key_value_heads \
|
|
or self.total_num_heads
|
|
if self.total_num_kv_heads >= self.tp_size:
|
|
assert self.total_num_kv_heads % self.tp_size == 0
|
|
else:
|
|
assert self.tp_size % self.total_num_kv_heads == 0
|
|
|
|
self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
|
|
self.head_dim = self.hidden_size // self.total_num_heads
|
|
self.q_size = self.num_heads * self.head_dim
|
|
self.kv_size = self.num_kv_heads * self.head_dim
|
|
self.max_position_embeddings = config.max_position_embeddings
|
|
self.rope_theta = config.rope_theta
|
|
|
|
# Attention input projection. Projects x -> (q, k, v)
|
|
self.qkv_proj = QKVParallelLinear(
|
|
self.hidden_size,
|
|
self.head_dim,
|
|
self.total_num_heads,
|
|
self.total_num_kv_heads,
|
|
bias=config.qkv_bias,
|
|
quant_config=quant_config,
|
|
)
|
|
|
|
self.tp_rank: Optional[int] = None
|
|
self.k_norm: Optional[nn.Module] = None
|
|
self.q_norm: Optional[nn.Module] = None
|
|
if config.attention_layer_norm:
|
|
self.tp_rank = get_tensor_model_parallel_rank()
|
|
self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim,
|
|
eps=config.layer_norm_eps)
|
|
self.q_norm = RMSNorm(config.hidden_size,
|
|
eps=config.layer_norm_eps)
|
|
|
|
# Rotary embeddings.
|
|
self.rotary_emb = get_rope(
|
|
self.head_dim,
|
|
rotary_dim=self.head_dim,
|
|
max_position=self.max_position_embeddings,
|
|
base=self.rope_theta,
|
|
)
|
|
self.scaling = self.head_dim**-0.5
|
|
self.attn = Attention(self.num_heads,
|
|
self.head_dim,
|
|
self.scaling,
|
|
num_kv_heads=self.num_kv_heads,
|
|
cache_config=cache_config,
|
|
quant_config=quant_config,
|
|
prefix=f"{prefix}.attn")
|
|
|
|
# Attention output projection.
|
|
self.o_proj = RowParallelLinear(
|
|
self.total_num_heads * self.head_dim,
|
|
self.hidden_size,
|
|
bias=False,
|
|
quant_config=quant_config,
|
|
)
|
|
|
|
def _apply_qk_norm(self, q: torch.Tensor,
|
|
k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
if self.tp_size > 1:
|
|
q = tensor_model_parallel_all_gather(q.contiguous())
|
|
k = tensor_model_parallel_all_gather(k.contiguous())
|
|
q = self.q_norm.forward_native(q)
|
|
k = self.k_norm.forward_native(k)
|
|
if self.tp_size > 1:
|
|
splitter = partial(split_tensor_along_last_dim,
|
|
num_partitions=self.tp_size)
|
|
q = splitter(q)[self.tp_rank]
|
|
k = splitter(k)[self.tp_rank]
|
|
return q, k
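    # NOTE: q_norm/k_norm are defined over the full (unsharded) hidden
    # dimensions, so under tensor parallelism the sharded q/k are all-gathered
    # before normalization and re-split along the last dim afterwards.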
|
|
|
|
def forward(
|
|
self,
|
|
positions: torch.Tensor,
|
|
hidden_states: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
qkv, _ = self.qkv_proj(hidden_states)
|
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
|
if self.q_norm is not None and self.k_norm is not None:
|
|
q, k = self._apply_qk_norm(q, k)
|
|
q, k = self.rotary_emb(positions, q, k)
|
|
attn_output = self.attn(q, k, v)
|
|
output, _ = self.o_proj(attn_output)
|
|
return output
|
|
|
|
|
|
class LanguageModelMLP(nn.Module):
|
|
"""Molmo's LLM mlp."""
|
|
|
|
def __init__(self,
|
|
config: PretrainedConfig,
|
|
input_dim: Optional[int] = None,
|
|
quant_config: Optional[QuantizationConfig] = None) -> None:
|
|
super().__init__()
|
|
self.hidden_size = config.hidden_size
|
|
self.intermediate_size = config.intermediate_size // 2
|
|
|
|
self.gate_up_proj = MergedColumnParallelLinear(
|
|
input_dim or self.hidden_size,
|
|
[self.intermediate_size] * 2,
|
|
bias=False,
|
|
quant_config=quant_config,
|
|
)
|
|
# Activation function.
|
|
self.act_fn = MulAndSilu()
|
|
# Feed-forward output projection.
|
|
self.down_proj = RowParallelLinear(
|
|
self.intermediate_size,
|
|
self.hidden_size,
|
|
bias=False,
|
|
quant_config=quant_config,
|
|
)
|
|
|
|
def forward(
|
|
self,
|
|
x: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
gate_up, _ = self.gate_up_proj(x)
|
|
x = self.act_fn(gate_up)
|
|
x, _ = self.down_proj(x)
|
|
return x
|
|
|
|
|
|
class ImageProjectorMLP(nn.Module):
|
|
"""Molmo's image_projector mlp."""
|
|
|
|
def __init__(
|
|
self,
|
|
config: PretrainedConfig,
|
|
input_dim: Optional[int] = None,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
) -> None:
|
|
super().__init__()
|
|
self.hidden_size = config.hidden_size
|
|
self.intermediate_size = config.intermediate_size // 2
|
|
|
|
self.merged_linear = MergedColumnParallelLinear(
|
|
input_dim or self.hidden_size,
|
|
[self.intermediate_size] * 2,
|
|
bias=False,
|
|
quant_config=quant_config,
|
|
)
|
|
# Activation function.
|
|
self.act_fn = SiluAndMul()
|
|
|
|
# Feed-forward output projection.
|
|
self.down_proj = RowParallelLinear(
|
|
self.intermediate_size,
|
|
self.hidden_size,
|
|
bias=False,
|
|
quant_config=quant_config,
|
|
)
|
|
|
|
def forward(
|
|
self,
|
|
x: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
gate_up, _ = self.merged_linear(x)
|
|
x = self.act_fn(gate_up)
|
|
x, _ = self.down_proj(x)
|
|
return x
|
|
|
|
|
|
class MolmoDecoderLayer(nn.Module):
|
|
|
|
def __init__(
|
|
self,
|
|
config: PretrainedConfig,
|
|
cache_config: Optional[CacheConfig] = None,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
prefix: str = "",
|
|
) -> None:
|
|
super().__init__()
|
|
# Attention block.
|
|
self.self_attn = MolmoAttention(config,
|
|
cache_config,
|
|
quant_config,
|
|
prefix=f"{prefix}.self_attn")
|
|
|
|
# MLP block.
|
|
self.mlp = LanguageModelMLP(config, quant_config=quant_config)
|
|
|
|
# LayerNorm
|
|
assert config.layer_norm_type == "rms"
|
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
|
eps=config.layer_norm_eps)
|
|
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
|
eps=config.layer_norm_eps)
|
|
|
|
def forward(
|
|
self,
|
|
positions: torch.Tensor,
|
|
hidden_states: torch.Tensor,
|
|
residual: Optional[torch.Tensor],
|
|
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
|
# Self Attention
|
|
if residual is None:
|
|
residual = hidden_states
|
|
hidden_states = self.input_layernorm(hidden_states)
|
|
else:
|
|
hidden_states, residual = self.input_layernorm(
|
|
hidden_states, residual)
|
|
hidden_states = self.self_attn(
|
|
positions=positions,
|
|
hidden_states=hidden_states,
|
|
)
|
|
|
|
hidden_states, residual = self.post_attention_layernorm(
|
|
hidden_states, residual)
|
|
hidden_states = self.mlp(hidden_states)
|
|
return hidden_states, residual
|
|
|
|
|
|
class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
|
|
|
|
def forward(
|
|
self,
|
|
positions: torch.Tensor,
|
|
hidden_states: torch.Tensor,
|
|
residual: Optional[torch.Tensor],
|
|
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
|
# Self Attention
|
|
residual = hidden_states
|
|
hidden_states = self.self_attn(
|
|
positions=positions,
|
|
hidden_states=hidden_states,
|
|
)
|
|
|
|
hidden_states = self.input_layernorm(hidden_states)
|
|
hidden_states = hidden_states + residual
|
|
residual = hidden_states
|
|
|
|
hidden_states = self.mlp(hidden_states)
|
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
|
hidden_states = hidden_states + residual
|
|
residual = None
|
|
return hidden_states, residual
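    # Unlike the pre-norm MolmoDecoderLayer above (which carries a fused
    # residual between layers), this "norm-after" variant normalizes the
    # attention/MLP outputs and adds the residual explicitly, so it always
    # returns residual=None.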
|
|
|
|
|
|
class MolmoVisionBackbone(nn.Module, SupportsQuant):
|
|
packed_modules_mapping = {"merged_linear": ["gate_proj", "up_proj"]}
|
|
|
|
def __init__(
|
|
self,
|
|
config: PretrainedConfig,
|
|
vision_config: VisionBackboneConfig,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
) -> None:
|
|
super().__init__()
|
|
self.vit_layers = VIT_LAYERS
|
|
self.image_num_patch = vision_config.image_num_patch
|
|
self.llm_patches_per_crop = (
|
|
(self.image_num_patch[0] + 1) // POOLING_SIZE,
|
|
(self.image_num_patch[1] + 1) // POOLING_SIZE,
|
|
)
|
|
self.image_vit = VisionTransformer(vision_config,
|
|
quant_config=quant_config)
|
|
self.num_prefix_tokens = self.image_vit.num_prefix_tokens
|
|
assert self.num_prefix_tokens in {
|
|
0, 1
|
|
}, "Only 0 or 1 prefix tokens are supported"
|
|
self.image_pooling_2d = MultiHeadDotProductAttention(
|
|
vision_config,
|
|
nlayers=len(self.vit_layers),
|
|
quant_config=quant_config)
|
|
self.image_projector = ImageProjectorMLP(
|
|
config,
|
|
input_dim=vision_config.image_emb_dim,
|
|
quant_config=quant_config,
|
|
)
|
|
|
|
image_dim = vision_config.image_emb_dim * len(self.vit_layers)
|
|
self.pad_embed = nn.Parameter(torch.zeros((2, image_dim)))
|
|
|
|
@property
|
|
def dtype(self) -> torch.dtype:
|
|
return self.image_vit.patch_embedding.weight.dtype
|
|
|
|
@property
|
|
def device(self) -> torch.device:
|
|
return self.image_vit.patch_embedding.weight.device
|
|
|
|
def encode_image(self, images: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
: param images: (batch_size, num_crops, num_patch, n_pixels)
|
|
"""
|
|
B, T, N, D = images.shape
|
|
|
|
mask = ~torch.all(
|
|
images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True)
|
|
|
|
images = images.view(B * T, N, D)
|
|
image_features = self.image_vit(images)
|
|
|
|
if self.vit_layers is not None:
|
|
features = []
|
|
for layer in self.vit_layers:
|
|
features.append(image_features[layer])
|
|
image_features = torch.cat(features, dim=-1)
|
|
else:
|
|
image_features = image_features[-1]
|
|
|
|
if self.num_prefix_tokens > 0:
|
|
image_features = image_features[:, 1:]
|
|
|
|
image_features = image_features * mask
|
|
image_features = image_features.view(B, T, N, -1)
|
|
|
|
return image_features
|
|
|
|
def forward(
|
|
self,
|
|
images: torch.Tensor,
|
|
image_masks: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
# image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) # noqa: E501
|
|
batch_size, num_image = images.shape[:2]
|
|
images = images.to(device=self.device, dtype=self.dtype)
|
|
image_features = self.encode_image(images)
|
|
|
|
og_dtype = image_features.dtype
|
|
assert image_masks is not None
|
|
pad_embed = self.pad_embed[:, None, None, None, :]
|
|
all_pad = image_masks == 0
|
|
partial_pad = torch.logical_and(
|
|
image_masks < 1,
|
|
torch.logical_not(all_pad)).to(dtype=torch.float32)
|
|
all_pad = all_pad.to(dtype=torch.float32)
|
|
image_features = image_features + pad_embed[0] * torch.unsqueeze(
|
|
all_pad, -1)
|
|
image_features = image_features + pad_embed[1] * torch.unsqueeze(
|
|
partial_pad, -1)
|
|
|
|
image_features = image_features.to(og_dtype)
|
|
|
|
image_features = image_features.reshape(
|
|
(batch_size, num_image) + self.image_num_patch + (-1, ), )
|
|
|
|
if (missing_w := self.image_num_patch[0] % POOLING_SIZE):
|
|
# Padding for image pooling (see below)
|
|
image_features = F.pad(
|
|
image_features,
|
|
(0, 0, 0, missing_w, 0, missing_w, 0, 0, 0, 0),
|
|
)
|
|
|
|
# image pooling
|
|
image_features = rearrange(
|
|
image_features,
|
|
'b n (h dh) (w dw) c -> (b n h w) (dh dw) c',
|
|
dh=POOLING_SIZE,
|
|
dw=POOLING_SIZE,
|
|
)
|
|
|
|
query = image_features.mean(-2, keepdim=True)
|
|
image_features = self.image_pooling_2d(query, image_features)
|
|
|
|
h, w = self.llm_patches_per_crop
|
|
image_features = image_features.view(batch_size, num_image, h * w, -1)
|
|
|
|
image_features = self.image_projector(image_features)
|
|
|
|
# image_features: (batch_size, num_image, num_patch, d_model)
|
|
return image_features
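    # Shape sketch with the default VisionBackboneConfig (illustrative only):
    #   images          -> (B, num_crops, 576, 14 * 14 * 3)
    #   encode_image    -> (B, num_crops, 576, 2 * 1024)  # two ViT layers
    #   2x2 attn pool   -> (B, num_crops, 144, 1024)      # 12x12 per crop
    #   image_projector -> (B, num_crops, 144, config.hidden_size)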
|
|
|
|
def load_weights(self, weights: Iterable[Tuple[str,
|
|
torch.Tensor]]) -> Set[str]:
|
|
stacked_params_mapping = [
|
|
# (param_name, shard_name, shard_id)
|
|
("merged_linear", "gate_proj", 0),
|
|
("merged_linear", "up_proj", 1),
|
|
]
|
|
params_dict = dict(self.named_parameters())
|
|
loaded_params: Set[str] = set()
|
|
|
|
for name, loaded_weight in weights:
|
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
|
if weight_name not in name:
|
|
continue
|
|
name = name.replace(weight_name, param_name)
|
|
# Skip loading extra bias for GPTQ models.
|
|
if name.endswith(".bias") and name not in params_dict:
|
|
continue
|
|
if is_pp_missing_parameter(name, self):
|
|
continue
|
|
param = params_dict[name]
|
|
weight_loader = param.weight_loader
|
|
weight_loader(param, loaded_weight, shard_id)
|
|
break
|
|
else:
|
|
if name.endswith(".bias") and name not in params_dict:
|
|
continue
|
|
if is_pp_missing_parameter(name, self):
|
|
continue
|
|
param = params_dict[name]
|
|
weight_loader = getattr(param, "weight_loader",
|
|
default_weight_loader)
|
|
weight_loader(param, loaded_weight)
|
|
loaded_params.add(name)
|
|
return loaded_params
|
|
|
|
|
|
@support_torch_compile
|
|
class MolmoModel(nn.Module, SupportsQuant):
|
|
|
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
|
super().__init__()
|
|
|
|
config = vllm_config.model_config.hf_config
|
|
cache_config = vllm_config.cache_config
|
|
quant_config = vllm_config.quant_config
|
|
|
|
self.config = config
|
|
|
|
self.embedding_size = config.embedding_size or config.vocab_size
|
|
self.embedding_size += ADDITIONAL_VOCAB_SIZE
|
|
self.embed_tokens = VocabParallelEmbedding(
|
|
self.embedding_size,
|
|
config.hidden_size,
|
|
quant_config=quant_config,
|
|
)
|
|
|
|
decoder_layer = MolmoDecoderNormAfterLayer if config.norm_after \
|
|
else MolmoDecoderLayer
|
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
|
config.num_hidden_layers,
|
|
lambda prefix: decoder_layer(
|
|
config, cache_config, quant_config, prefix=prefix),
|
|
prefix=f"{prefix}.layers",
|
|
)
|
|
|
|
assert config.layer_norm_type == "rms"
|
|
self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps)
|
|
|
|
self.make_empty_intermediate_tensors = (
|
|
make_empty_intermediate_tensors_factory(
|
|
["hidden_states", "residual"], config.hidden_size))
|
|
|
|
def get_input_embeddings(
|
|
self,
|
|
input_ids: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
return self.embed_tokens(input_ids)
|
|
|
|
def forward(
|
|
self,
|
|
input_ids: torch.Tensor,
|
|
positions: torch.Tensor,
|
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
inputs_embeds: Optional[torch.Tensor] = None,
|
|
) -> torch.Tensor:
|
|
if get_pp_group().is_first_rank:
|
|
if inputs_embeds is not None:
|
|
hidden_states = inputs_embeds
|
|
else:
|
|
hidden_states = self.embed_tokens(input_ids)
|
|
residual = None
|
|
else:
|
|
assert intermediate_tensors is not None
|
|
hidden_states = intermediate_tensors["hidden_states"]
|
|
residual = intermediate_tensors["residual"]
|
|
|
|
# Apply blocks one-by-one.
|
|
for layer in self.layers[self.start_layer:self.end_layer]:
|
|
hidden_states, residual = layer(
|
|
positions,
|
|
hidden_states,
|
|
residual,
|
|
)
|
|
if not get_pp_group().is_last_rank:
|
|
return IntermediateTensors({
|
|
"hidden_states": hidden_states,
|
|
"residual": residual
|
|
})
|
|
if residual is not None:
|
|
hidden_states, _ = self.norm(hidden_states, residual)
|
|
else:
|
|
hidden_states = self.norm(hidden_states)
|
|
return hidden_states
|
|
|
|
def load_weights(self, weights: Iterable[Tuple[str,
|
|
torch.Tensor]]) -> Set[str]:
|
|
params_dict = dict(self.named_parameters())
|
|
loaded_params: Set[str] = set()
|
|
|
|
for name, loaded_weight in weights:
|
|
if name.endswith(".bias") and name not in params_dict:
|
|
continue
|
|
if is_pp_missing_parameter(name, self):
|
|
continue
|
|
|
|
param = params_dict[name]
|
|
weight_loader = getattr(param, "weight_loader",
|
|
default_weight_loader)
|
|
weight_loader(param, loaded_weight)
|
|
loaded_params.add(name)
|
|
return loaded_params
|
|
|
|
|
|
def _lowest_multiple(x: int, k: int) -> int:
|
|
return (x // k) * k
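# e.g. _lowest_multiple(25, 2) == 24 and _lowest_multiple(24, 3) == 24.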
|
|
|
|
|
|
def get_num_patches(
|
|
num_tiles: int,
|
|
*,
|
|
crop_patches: int,
|
|
left_margin: int,
|
|
right_margin: int,
|
|
pooling_size: int,
|
|
) -> int:
|
|
if num_tiles == 1:
|
|
return _lowest_multiple(crop_patches + pooling_size - 1, pooling_size)
|
|
|
|
crop_window_patches = crop_patches - (left_margin + right_margin)
|
|
|
|
left_num = _lowest_multiple(
|
|
crop_window_patches + left_margin + pooling_size - 1,
|
|
pooling_size,
|
|
)
|
|
middle_num = _lowest_multiple(
|
|
crop_window_patches + pooling_size - 1,
|
|
pooling_size,
|
|
)
|
|
right_num = _lowest_multiple(
|
|
crop_window_patches + right_margin + pooling_size - 1,
|
|
pooling_size,
|
|
)
|
|
|
|
return left_num + (num_tiles - 2) * middle_num + right_num
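# Worked example (hypothetical numbers: a 336px crop of 14px patches, i.e.
# crop_patches=24, with 4-patch overlap margins; the real values come from the
# HF image processor config):
#
#     get_num_patches(1, crop_patches=24, left_margin=4, right_margin=4,
#                     pooling_size=2)
#     # -> _lowest_multiple(24 + 1, 2) = 24
#
#     get_num_patches(2, crop_patches=24, left_margin=4, right_margin=4,
#                     pooling_size=2)
#     # crop_window_patches = 24 - 8 = 16
#     # left = right = _lowest_multiple(16 + 4 + 1, 2) = 20, middle term = 0
#     # -> 40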
|
|
|
|
|
|
def get_patches_grid_size(
|
|
*,
|
|
tiling_h: int,
|
|
tiling_w: int,
|
|
crop_patches: int,
|
|
left_margin: int,
|
|
right_margin: int,
|
|
pooling_size: int,
|
|
) -> tuple[int, int]:
|
|
nrows = get_num_patches(
|
|
tiling_h,
|
|
crop_patches=crop_patches,
|
|
left_margin=left_margin,
|
|
right_margin=right_margin,
|
|
pooling_size=pooling_size,
|
|
)
|
|
ncols = get_num_patches(
|
|
tiling_w,
|
|
crop_patches=crop_patches,
|
|
left_margin=left_margin,
|
|
right_margin=right_margin,
|
|
pooling_size=pooling_size,
|
|
)
|
|
|
|
return nrows, ncols
|
|
|
|
|
|
def get_candidate_tilings(max_num: int) -> list[tuple[int, int]]:
|
|
tilings = [(i, j) for i in range(1, max_num + 1)
|
|
for j in range(1, max_num + 1) if i * j <= max_num]
|
|
return sorted(tilings, key=lambda x: x[0] * x[1])
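# Example: get_candidate_tilings(4) enumerates every (rows, cols) grid using
# at most 4 crops, ordered by crop count (ties keep generation order):
#     [(1, 1), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (2, 2), (4, 1)]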
|
|
|
|
|
|
def select_tiling(
|
|
*,
|
|
height: int,
|
|
width: int,
|
|
patch_size: int,
|
|
max_num_patches: int,
|
|
):
|
|
tilings = get_candidate_tilings(max_num_patches)
|
|
candidate_tilings = np.array(tilings, dtype=np.int32)
|
|
candidate_resolutions = candidate_tilings * patch_size
|
|
|
|
original_size = np.array([height, width], dtype=np.float32)
|
|
required_scale_d = candidate_resolutions.astype(np.float32) / original_size
|
|
required_scale = required_scale_d.min(axis=-1, keepdims=True)
|
|
|
|
if (required_scale < 1).all():
|
|
ix = required_scale.argmax()
|
|
else:
|
|
ix = np.where(required_scale < 1.0, 10e9, required_scale).argmin()
|
|
|
|
return candidate_tilings[ix]
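# Worked example (hypothetical sizes): with patch_size=224, max_num_patches=4
# and a 448x448 image, the (2, 2) tiling covers the image exactly
# (required_scale == 1.0) while every other candidate would force the image
# to be downscaled in at least one dimension (required_scale == 0.5), so
# select_tiling(...) returns [2, 2].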
|
|
|
|
|
|
class MolmoProcessorWrapper:
|
|
"""
|
|
Wraps :class:`MolmoProcessor` so that it can be called directly.
|
|
|
|
The original definition can be found here:
|
|
https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py
|
|
"""
|
|
|
|
def __init__(self, processor: ProcessorMixin):
|
|
super().__init__()
|
|
|
|
self.processor = processor
|
|
|
|
@cached_property
|
|
def vocab(self) -> dict[str, int]:
|
|
return self.processor.tokenizer.vocab # type: ignore
|
|
|
|
@cached_property
|
|
def max_crops(self) -> int:
|
|
image_processor = self.processor.image_processor # type: ignore
|
|
|
|
max_crops = image_processor.max_crops
|
|
assert isinstance(max_crops, int)
|
|
|
|
return max_crops
|
|
|
|
@cached_property
|
|
def base_image_input_size(self) -> tuple[int, int]:
|
|
image_processor = self.processor.image_processor # type: ignore
|
|
|
|
base_image_input_size = image_processor.base_image_input_size
|
|
if isinstance(base_image_input_size, int):
|
|
return base_image_input_size, base_image_input_size
|
|
|
|
return tuple(base_image_input_size)
|
|
|
|
@cached_property
|
|
def image_patch_size(self) -> int:
|
|
image_processor = self.processor.image_processor # type: ignore
|
|
|
|
image_patch_size = image_processor.image_patch_size
|
|
assert isinstance(image_patch_size, int)
|
|
|
|
return image_patch_size
|
|
|
|
@cached_property
|
|
def overlap_margins(self) -> tuple[int, int]:
|
|
image_processor = self.processor.image_processor # type: ignore
|
|
|
|
left_margin, right_margin = image_processor.overlap_margins
|
|
assert isinstance(left_margin, int)
|
|
assert isinstance(right_margin, int)
|
|
|
|
return left_margin, right_margin
|
|
|
|
@cached_property
|
|
def image_token_length_w(self) -> int:
|
|
image_processor = self.processor.image_processor # type: ignore
|
|
|
|
image_token_length_w = image_processor.image_token_length_w
|
|
assert isinstance(image_token_length_w, int)
|
|
|
|
return image_token_length_w
|
|
|
|
@cached_property
|
|
def image_token_length_h(self) -> int:
|
|
image_processor = self.processor.image_processor # type: ignore
|
|
|
|
image_token_length_h = image_processor.image_token_length_h
|
|
assert isinstance(image_token_length_h, int)
|
|
|
|
return image_token_length_h
|
|
|
|
@property
|
|
def message_format(self) -> Optional[str]:
|
|
return "role"
|
|
|
|
@property
|
|
def always_start_with_space(self) -> bool:
|
|
return True
|
|
|
|
@cached_property
|
|
def image_patch_id(self) -> int:
|
|
return self.vocab[IMAGE_PATCH_TOKEN]
|
|
|
|
@cached_property
|
|
def im_col_id(self) -> int:
|
|
return self.vocab[IM_COL_TOKEN]
|
|
|
|
@cached_property
|
|
def im_start_id(self) -> int:
|
|
return self.vocab[IM_START_TOKEN]
|
|
|
|
@cached_property
|
|
def im_end_id(self) -> int:
|
|
return self.vocab[IM_END_TOKEN]
|
|
|
|
@property
|
|
def pooling_size(self) -> int:
|
|
return POOLING_SIZE
|
|
|
|
def select_tiling(
|
|
self,
|
|
*,
|
|
image_width: int,
|
|
image_height: int,
|
|
) -> tuple[int, int]:
|
|
max_crops = self.max_crops
|
|
left_margin, right_margin = self.overlap_margins
|
|
base_image_input_size = self.base_image_input_size
|
|
base_image_input_d = self.image_patch_size
|
|
|
|
total_margin_pixels = base_image_input_d * (right_margin + left_margin)
|
|
crop_patches = base_image_input_size[0] // base_image_input_d
|
|
crop_window_patches = crop_patches - (right_margin + left_margin)
|
|
crop_window_size = crop_window_patches * base_image_input_d
|
|
tiling_h, tiling_w = select_tiling(
|
|
height=image_height - total_margin_pixels,
|
|
width=image_width - total_margin_pixels,
|
|
patch_size=crop_window_size,
|
|
max_num_patches=max_crops,
|
|
)
|
|
|
|
return tiling_w, tiling_h
|
|
|
|
def get_patches_grid_size(
|
|
self,
|
|
*,
|
|
image_width: int,
|
|
image_height: int,
|
|
) -> tuple[int, int]:
|
|
left_margin, right_margin = self.overlap_margins
|
|
base_image_input_size = self.base_image_input_size
|
|
base_image_input_d = self.image_patch_size
|
|
pooling_size = self.pooling_size
|
|
|
|
crop_patches = base_image_input_size[0] // base_image_input_d
|
|
tiling_w, tiling_h = self.select_tiling(
|
|
image_height=image_height,
|
|
image_width=image_width,
|
|
)
|
|
|
|
nrows, ncols = get_patches_grid_size(
|
|
tiling_h=tiling_h,
|
|
tiling_w=tiling_w,
|
|
crop_patches=crop_patches,
|
|
left_margin=left_margin,
|
|
right_margin=right_margin,
|
|
pooling_size=pooling_size,
|
|
)
|
|
|
|
return ncols, nrows
|
|
|
|
def __call__(
|
|
self,
|
|
text: Optional[Union[TextInput, list[TextInput]]] = None,
|
|
images: Optional[Union[ImageInput, list[ImageInput]]] = None,
|
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
|
**kwargs,
|
|
) -> BatchFeature:
|
|
outputs = self.processor.process( # type: ignore
|
|
text, images, **kwargs)
|
|
|
|
if images is None:
|
|
images = []
|
|
if not isinstance(images, list):
|
|
images = [images]
|
|
|
|
input_ids: torch.Tensor = outputs.pop("input_ids")
|
|
outputs["input_ids"] = input_ids.unsqueeze(0)
|
|
|
|
image_input_idx = outputs.pop("image_input_idx", None)
|
|
if image_input_idx is not None:
|
|
feat_is_patch = image_input_idx >= 0
|
|
|
|
input_is_embed = torch.isin(
|
|
input_ids,
|
|
torch.tensor([
|
|
self.image_patch_id,
|
|
self.im_col_id,
|
|
self.im_start_id,
|
|
self.im_end_id,
|
|
]),
|
|
)
|
|
embed_ids = input_ids[input_is_embed]
|
|
embed_is_patch = embed_ids == self.image_patch_id
|
|
assert embed_is_patch.sum() == feat_is_patch.sum()
|
|
|
|
# image_tokens = extra_joint + joint
|
|
# Both `extra_joint` and `joint` have `im_start_id` and `im_end_id`
|
|
embed_start = torch.nonzero(embed_ids == self.im_start_id)[::2, 0]
|
|
embed_end = torch.nonzero(embed_ids == self.im_end_id)[1::2, 0]
|
|
assert len(embed_start) == len(embed_end) == len(images)
|
|
|
|
embed_is_patch = [
|
|
embed_is_patch[start:end + 1]
|
|
for start, end in zip(embed_start, embed_end)
|
|
]
|
|
|
|
tilings = [
|
|
self.select_tiling(
|
|
image_width=image.size[0],
|
|
image_height=image.size[1],
|
|
) for image in images
|
|
]
|
|
# For each image: tiling_h * tiling_w + extra
|
|
num_crops = torch.tensor(tilings).prod(-1) + 1
|
|
assert num_crops.sum() == len(feat_is_patch)
|
|
|
|
outputs["feat_is_patch"] = feat_is_patch
|
|
outputs["embed_is_patch"] = embed_is_patch
|
|
outputs["num_crops"] = num_crops
|
|
outputs["img_patch_id"] = self.image_patch_id
|
|
|
|
return BatchFeature(outputs)
|
|
|
|
|
|
class MolmoProcessingInfo(BaseProcessingInfo):
|
|
|
|
def get_hf_processor(self, **kwargs: object) -> MolmoProcessorWrapper:
|
|
processor = self.ctx.get_hf_processor(**kwargs)
|
|
return MolmoProcessorWrapper(processor)
|
|
|
|
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
|
return {"image": None}
|
|
|
|
def get_mm_max_tokens_per_item(
|
|
self,
|
|
seq_len: int,
|
|
mm_counts: Mapping[str, int],
|
|
) -> Mapping[str, int]:
|
|
return {"image": self.get_max_image_tokens()}
|
|
|
|
def get_num_image_tokens(
|
|
self,
|
|
*,
|
|
image_width: int,
|
|
image_height: int,
|
|
processor: Optional[MolmoProcessorWrapper],
|
|
) -> int:
|
|
if processor is None:
|
|
processor = self.get_hf_processor()
|
|
|
|
ncols, nrows = processor.get_patches_grid_size(
|
|
image_width=image_width,
|
|
image_height=image_height,
|
|
)
|
|
pooling_size = processor.pooling_size
|
|
|
|
base_image_input_size = processor.base_image_input_size
|
|
base_image_input_d = processor.image_patch_size
|
|
|
|
crop_patches = base_image_input_size[0] // base_image_input_d
|
|
|
|
per_row = ncols // pooling_size + 1
|
|
joint = per_row * (nrows // pooling_size) + 2
|
|
image_token_length = (crop_patches + pooling_size - 1) // pooling_size
|
|
resize = (image_token_length + 1) * image_token_length + 2
|
|
|
|
return resize + joint
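    # Rough token-count sketch for get_num_image_tokens above (hypothetical:
    # 336px base input, 14px patches, pooling_size=2, and an image whose grid
    # works out to nrows = ncols = 24):
    #   image_token_length = (24 + 2 - 1) // 2 = 12
    #   resize = (12 + 1) * 12 + 2 = 158  # resized full image + im_start/end
    #   joint  = (24 // 2 + 1) * (24 // 2) + 2 = 158
    #   total  = 316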
|
|
|
|
def get_max_image_tokens(self) -> int:
|
|
target_width, target_height = self.get_image_size_with_most_features()
|
|
|
|
return self.get_num_image_tokens(
|
|
image_width=target_width,
|
|
image_height=target_height,
|
|
processor=None,
|
|
)
|
|
|
|
def get_image_size_with_most_features(self) -> ImageSize:
|
|
processor = self.get_hf_processor()
|
|
|
|
tilings = get_candidate_tilings(processor.max_crops)
|
|
base_h, base_w = processor.base_image_input_size
|
|
|
|
largest_feature_size, largest_feature_pinpoint = 0, None
|
|
for wr, hr in tilings:
|
|
width, height = base_w * wr, base_h * hr
|
|
|
|
feat_size = self.get_num_image_tokens(
|
|
image_width=width,
|
|
image_height=height,
|
|
processor=processor,
|
|
)
|
|
if feat_size > largest_feature_size:
|
|
largest_feature_size = feat_size
|
|
largest_feature_pinpoint = ImageSize(width=width,
|
|
height=height)
|
|
|
|
if largest_feature_size == 0 or largest_feature_pinpoint is None:
|
|
raise ValueError("Cannot have a largest feature size of 0!")
|
|
|
|
return largest_feature_pinpoint
|
|
|
|
|
|
class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
|
|
|
|
def get_dummy_processor_inputs(
|
|
self,
|
|
seq_len: int,
|
|
mm_counts: Mapping[str, int],
|
|
) -> ProcessorInputs:
|
|
target_width, target_height = \
|
|
self.info.get_image_size_with_most_features()
|
|
num_images = mm_counts.get("image", 0)
|
|
|
|
mm_data = {
|
|
"image":
|
|
self._get_dummy_images(width=target_width,
|
|
height=target_height,
|
|
num_images=num_images)
|
|
}
|
|
|
|
return ProcessorInputs(
|
|
prompt_text="",
|
|
mm_data=mm_data,
|
|
)
|
|
|
|
|
|
class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
|
|
|
def _apply_hf_processor_tokens_only(
|
|
self,
|
|
prompt_tokens: list[int],
|
|
) -> list[int]:
|
|
processor = self.info.get_hf_processor()
|
|
|
|
# Apply the chat template to the tokens
|
|
tokens = processor.processor.get_tokens_input( # type: ignore
|
|
self.info.get_tokenizer().decode(prompt_tokens),
|
|
message_format=processor.message_format,
|
|
always_start_with_space=processor.always_start_with_space,
|
|
)
|
|
|
|
processed_data = self.info.ctx.call_hf_processor(
|
|
processor, # type: ignore
|
|
dict(tokens=tokens),
|
|
)
|
|
prompt_ids, = processed_data.pop("input_ids").tolist()
|
|
|
|
return prompt_ids
|
|
|
|
def _get_mm_fields_config(
|
|
self,
|
|
hf_inputs: BatchFeature,
|
|
hf_processor_mm_kwargs: Mapping[str, object],
|
|
) -> Mapping[str, MultiModalFieldConfig]:
|
|
num_crops = hf_inputs.get("num_crops", torch.empty(0))
|
|
num_images = len(num_crops)
|
|
|
|
return dict(
|
|
images=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
|
|
image_masks=MultiModalFieldConfig.flat_from_sizes(
|
|
"image", num_crops),
|
|
feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
|
|
"image", num_crops),
|
|
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
|
num_crops=MultiModalFieldConfig.batched("image"),
|
|
img_patch_id=MultiModalFieldConfig.shared("image", num_images),
|
|
)
|
|
|
|
def _get_prompt_updates(
|
|
self,
|
|
mm_items: MultiModalDataItems,
|
|
hf_processor_mm_kwargs: Mapping[str, object],
|
|
out_mm_kwargs: MultiModalKwargs,
|
|
) -> Sequence[PromptUpdate]:
|
|
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
|
|
|
image_token_length_w = processor.image_token_length_w
|
|
image_token_length_h = processor.image_token_length_h
|
|
pooling_size = processor.pooling_size
|
|
|
|
img_patch_id = processor.image_patch_id
|
|
img_col_id = processor.im_col_id
|
|
img_start_id = processor.im_start_id
|
|
img_end_id = processor.im_end_id
|
|
|
|
extra_row = [img_patch_id] * image_token_length_w + [img_col_id]
|
|
extra_joint = ([img_start_id] + extra_row * image_token_length_h +
|
|
[img_end_id])
|
|
|
|
def get_insertion_molmo(item_idx: int):
|
|
images = mm_items.get_items("image", ImageProcessorItems)
|
|
image_size = images.get_image_size(item_idx)
|
|
|
|
ncols, nrows = processor.get_patches_grid_size(
|
|
image_width=image_size.width,
|
|
image_height=image_size.height,
|
|
)
|
|
|
|
joint_row = ([img_patch_id] * ((ncols + 1) // pooling_size) +
|
|
[img_col_id])
|
|
joint = ([img_start_id] + joint_row *
|
|
((nrows + 1) // pooling_size) + [img_end_id])
|
|
|
|
image_tokens = extra_joint + joint
|
|
return image_tokens
|
|
|
|
return [
|
|
PromptInsertion(
|
|
modality="image",
|
|
target=PromptIndexTargets.prefix("<|endoftext|>"),
|
|
insertion=get_insertion_molmo,
|
|
)
|
|
]
|
|
|
|
|
|
@MULTIMODAL_REGISTRY.register_processor(MolmoMultiModalProcessor,
|
|
info=MolmoProcessingInfo,
|
|
dummy_inputs=MolmoDummyInputsBuilder)
|
|
class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
|
SupportsQuant):
|
|
hf_to_vllm_mapper = WeightsMapper(
|
|
orig_to_new_substr={
|
|
# vision backbone mapping
|
|
"image_projector.w1.": "image_projector.gate_proj.",
|
|
"image_projector.w3.": "image_projector.up_proj.",
|
|
"image_projector.w2.": "image_projector.down_proj.",
|
|
# language backbone mapping
|
|
"att_proj": "self_attn.qkv_proj",
|
|
"attn_out": "self_attn.o_proj",
|
|
"q_norm": "self_attn.q_norm",
|
|
"k_norm": "self_attn.k_norm",
|
|
"ff_proj": "mlp.gate_up_proj",
|
|
"ff_out": "mlp.down_proj",
|
|
"attn_norm": "input_layernorm",
|
|
"ff_norm": "post_attention_layernorm",
|
|
},
|
|
orig_to_new_prefix={
|
|
# vision backbone mapping
|
|
"model.vision_backbone.": "vision_backbone.",
|
|
# language backbone mapping
|
|
"model.transformer.blocks.": "model.layers.",
|
|
"model.transformer.ln_f.": "model.norm.",
|
|
            # lm_head is first renamed to "model.transformer.mlp.down_proj"
            # by the substring rules above ("ff_out" -> "mlp.down_proj"), so
            # a second, prefix-level renaming is needed to recover "lm_head".
|
|
"model.transformer.mlp.down_proj.": "lm_head.",
|
|
},
|
|
)
|
|
|
|
packed_modules_mapping = {
|
|
"qkv_proj": ["qkv_proj"],
|
|
"gate_up_proj": ["gate_up_proj"], # language model
|
|
"merged_linear": ["gate_proj", "up_proj"] # image_projector
|
|
}
|
|
|
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
|
super().__init__()
|
|
config = vllm_config.model_config.hf_config
|
|
quant_config = vllm_config.quant_config
|
|
multimodal_config = vllm_config.model_config.multimodal_config
|
|
lora_config = vllm_config.lora_config
|
|
self.config = config
|
|
self.multimodal_config = multimodal_config
|
|
self.lora_config = lora_config
|
|
|
|
vision_config = VisionBackboneConfig()
|
|
self.vision_backbone = MolmoVisionBackbone(config, vision_config,
|
|
quant_config)
|
|
self.model = MolmoModel(vllm_config=vllm_config,
|
|
prefix=maybe_prefix(prefix, "model"))
|
|
self.img_patch_id = None
|
|
|
|
if self.config.weight_tying:
|
|
self.lm_head = self.model.transformer.wte
|
|
else:
|
|
self.lm_head = ParallelLMHead(
|
|
config.embedding_size or config.vocab_size,
|
|
config.hidden_size,
|
|
quant_config=quant_config,
|
|
)
|
|
|
|
self.logits_processor = LogitsProcessor(config.embedding_size
|
|
or config.vocab_size)
|
|
self.sampler = get_sampler()
|
|
|
|
self.make_empty_intermediate_tensors = (
|
|
self.model.make_empty_intermediate_tensors)
|
|
|
|
def _parse_and_validate_image_input(
|
|
self,
|
|
**kwargs: object,
|
|
) -> Optional[MolmoImageInputs]:
|
|
images = kwargs.pop("images", None)
|
|
if images is None:
|
|
return None
|
|
|
|
if not isinstance(images, (torch.Tensor, list)):
|
|
raise ValueError("Incorrect type of images. "
|
|
f"Got type: {type(images)}")
|
|
|
|
image_masks = kwargs.pop("image_masks", None)
|
|
if not (image_masks is None or isinstance(image_masks,
|
|
(torch.Tensor, list))):
|
|
raise ValueError("Incorrect type of image_masks. "
|
|
f"Got type: {type(image_masks)}")
|
|
|
|
feat_is_patch = kwargs.pop("feat_is_patch", None)
|
|
if not isinstance(feat_is_patch, (torch.Tensor, list)):
|
|
raise ValueError("Incorrect type of feat_is_patch. "
|
|
f"Got type: {type(feat_is_patch)}")
|
|
|
|
embed_is_patch = kwargs.pop("embed_is_patch", None)
|
|
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
|
raise ValueError("Incorrect type of embed_is_patch. "
|
|
f"Got type: {type(embed_is_patch)}")
|
|
|
|
num_crops = kwargs.pop("num_crops", None)
|
|
if not isinstance(num_crops, (torch.Tensor, list)):
|
|
raise ValueError("Incorrect type of num_crops. "
|
|
f"Got type: {type(num_crops)}")
|
|
|
|
img_patch_id = kwargs.pop("img_patch_id", None)
|
|
if not isinstance(img_patch_id, torch.Tensor):
|
|
raise ValueError("Incorrect type of img_patch_id. "
|
|
f"Got type: {type(img_patch_id)}")
|
|
self.img_patch_id = img_patch_id.flatten().unique().item()
|
|
|
|
embed_is_patch = flatten_bn(embed_is_patch)
|
|
num_crops = flatten_bn(num_crops, concat=True)
|
|
|
|
return MolmoImageInputs(
|
|
images=images,
|
|
image_masks=image_masks,
|
|
feat_is_patch=feat_is_patch,
|
|
embed_is_patch=embed_is_patch,
|
|
num_crops=num_crops,
|
|
)
|
|
|
|
def _process_image_input(
|
|
self,
|
|
image_input: MolmoImageInputs,
|
|
) -> list[torch.Tensor]:
|
|
images = image_input["images"]
|
|
image_masks = image_input["image_masks"]
|
|
feat_is_patch = image_input["feat_is_patch"]
|
|
num_crops = image_input["num_crops"]
|
|
|
|
# Call the vision backbone on the whole batch at once
|
|
images_flat = flatten_bn(images, concat=True)
|
|
image_masks_flat = (None if image_masks is None else flatten_bn(
|
|
image_masks, concat=True))
|
|
feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True)
|
|
|
|
image_features_flat = self.vision_backbone(
|
|
images=images_flat.unsqueeze(0),
|
|
image_masks=(None if image_masks_flat is None else
|
|
image_masks_flat.unsqueeze(0)),
|
|
).squeeze(0)
|
|
|
|
# Only the features corresponding to patch tokens are relevant
|
|
return [
|
|
feats[f_is_patch] for feats, f_is_patch in zip(
|
|
image_features_flat.split(num_crops.tolist()),
|
|
feat_is_patch_flat.split(num_crops.tolist()),
|
|
)
|
|
]
|
|
|
|
def get_multimodal_embeddings(
|
|
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
|
image_input = self._parse_and_validate_image_input(**kwargs)
|
|
if image_input is None:
|
|
return None
|
|
|
|
image_features = self._process_image_input(image_input)
|
|
|
|
return scatter_patch_features(
|
|
image_features,
|
|
image_input["embed_is_patch"],
|
|
)
|
|
|
|
def get_input_embeddings(
|
|
self,
|
|
input_ids: torch.Tensor,
|
|
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
|
) -> torch.Tensor:
|
|
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
|
if multimodal_embeddings is not None:
|
|
assert self.img_patch_id is not None
|
|
|
|
inputs_embeds = merge_multimodal_embeddings(
|
|
input_ids,
|
|
inputs_embeds,
|
|
select_patch_features(multimodal_embeddings),
|
|
self.img_patch_id,
|
|
)
|
|
return inputs_embeds
|
|
|
|
def forward(
|
|
self,
|
|
input_ids: torch.LongTensor,
|
|
positions: torch.LongTensor,
|
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
inputs_embeds: Optional[torch.Tensor] = None,
|
|
**kwargs: object,
|
|
) -> SamplerOutput:
|
|
|
|
if intermediate_tensors is not None:
|
|
inputs_embeds = None
|
|
|
|
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
|
# condition is for v0 compatibility.
|
|
elif inputs_embeds is None:
|
|
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
|
inputs_embeds = self.get_input_embeddings(input_ids,
|
|
vision_embeddings)
|
|
input_ids = None
|
|
|
|
hidden_states = self.model(input_ids,
|
|
positions,
|
|
intermediate_tensors,
|
|
inputs_embeds=inputs_embeds)
|
|
|
|
return hidden_states
|
|
|
|
def compute_logits(self, hidden_states: torch.Tensor,
|
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
|
logits = self.logits_processor(self.lm_head, hidden_states,
|
|
sampling_metadata)
|
|
return logits
|
|
|
|
def sample(
|
|
self,
|
|
logits: torch.Tensor,
|
|
sampling_metadata: SamplingMetadata,
|
|
) -> Optional[SamplerOutput]:
|
|
next_tokens = self.sampler(logits, sampling_metadata)
|
|
return next_tokens
|
|
|
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
|
|
|
loader = AutoWeightsLoader(self)
|
|
weights = _get_weights_with_merged_embedding(weights)
|
|
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
|
|
|
def get_mm_mapping(self) -> MultiModelKeys:
|
|
"""
|
|
Get the module prefix in multimodal models
|
|
"""
|
|
return MultiModelKeys.from_string_field(
|
|
language_model="model",
|
|
connector="vision_backbone.image_projector",
|
|
tower_model="vision_backbone",
|
|
)
|
|
|
|
|
|
def _get_weights_with_merged_embedding(
|
|
weights: Iterable[Tuple[str, torch.Tensor]]
|
|
) -> Iterable[Tuple[str, torch.Tensor]]:
|
|
embedding_weights = {}
|
|
for name, weight in weights:
|
|
if "wte.embedding" in name:
|
|
embedding_weights["embedding"] = weight
|
|
elif "wte.new_embedding" in name:
|
|
embedding_weights["new_embedding"] = weight
|
|
else:
|
|
yield (name, weight)
|
|
    # This is compatible with most quantization methods, since they do not
    # quantize embed_tokens.
|
|
embedding_weights = torch.cat(
|
|
[embedding_weights["embedding"], embedding_weights["new_embedding"]],
|
|
dim=0,
|
|
)
|
|
yield ("model.embed_tokens.weight", embedding_weights)
|