diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index bf86a72e..926ffab6 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -399,6 +399,12 @@ Text Generation - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc. - - + * - :code:`MolmoForCausalLM` + - Molmo + - Image + - :code:`allenai/Molmo-7B-D-0924`, :code:`allenai/Molmo-72B-0924`, etc. + - + - ✅︎ * - :code:`NVLM_D_Model` - NVLM-D 1.0 - Image\ :sup:`E+` diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 8d6818e7..4c88dcc2 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -300,6 +300,23 @@ def run_mllama(question: str, modality: str): return llm, prompt, stop_token_ids +# Molmo +def run_molmo(question, modality): + assert modality == "image" + + model_name = "allenai/Molmo-7B-D-0924" + + llm = LLM( + model=model_name, + trust_remote_code=True, + dtype="bfloat16", + ) + + prompt = question + stop_token_ids = None + return llm, prompt, stop_token_ids + + # GLM-4v def run_glm4v(question: str, modality: str): assert modality == "image" @@ -331,6 +348,7 @@ model_example_map = { "qwen_vl": run_qwen_vl, "qwen2_vl": run_qwen2_vl, "mllama": run_mllama, + "molmo": run_molmo, "glm4v": run_glm4v, } diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 1b82b454..41354dc6 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -163,6 +163,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): return "<|image|>" if model_type == "qwen2_vl": return "<|vision_start|><|image_pad|><|vision_end|>" + if model_type == "molmo": + return "" raise TypeError(f"Unknown model type: {model_type}") elif modality == "audio": diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index eaa2b93e..d6637351 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -20,4 +20,4 @@ __all__ = [ "supports_multimodal", "SupportsPP", "supports_pp", -] +] \ No newline at end of file diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py new file mode 100644 index 00000000..ccfee165 --- /dev/null +++ b/vllm/model_executor/models/molmo.py @@ -0,0 +1,1290 @@ +import logging +import math +import re +from array import array +from dataclasses import dataclass +from functools import lru_cache, partial +from typing import (Any, Iterable, List, Mapping, Optional, Tuple, TypedDict, + Union) + +import torch +from einops import rearrange +from PIL import Image +from torch import nn +from torch.nn import functional as F +from transformers import PretrainedConfig + +import vllm.envs as envs +from vllm.attention import Attention, AttentionMetadata +from vllm.attention.selector import (_Backend, backend_name_to_enum, + get_global_forced_attn_backend) +from vllm.config import CacheConfig, MultiModalConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather) +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.utils import make_layers +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs +from vllm.platforms import current_platform +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SequenceData) +from vllm.transformers_utils.processor import get_processor + +log = logging.getLogger(__name__) + +# TODO: hard-coded for now. Consider making it configurable. +VIT_LAYERS = [-2, -9] +NUM_PREFIX_TOKENS = 1 +ADDITIONAL_VOCAB_SIZE = 128 + + +class MolmoImageInputs(TypedDict): + images: torch.Tensor + """Shape: + `(batch_size, num_crops, num_patch, patch_dim)` + """ + + image_input_idx: torch.Tensor + """Shape: + `(batch_size, num_crops, num_patch)` + """ + + seq_len: torch.Tensor + """Shape: + `(batch_size, )` + """ + + image_masks: Optional[torch.Tensor] + """Shape: + `(batch_size, num_crops, num_patch)` + """ + + +@dataclass +class VisionBackboneConfig: + image_default_input_size: Tuple[int, int] = (336, 336) + image_patch_size: int = 14 + image_pos_patch_size: int = 14 + image_emb_dim: int = 1024 + image_num_heads: int = 16 + image_num_key_value_heads: int = 16 + image_num_layers: int = 23 + image_mlp_dim: int = 4096 + image_mlp_activations: str = "quick_gelu" + image_num_pos: int = 577 + image_norm_eps: float = 1e-5 + + def __post_init__(self): + self.image_default_input_size = tuple( + self.image_default_input_size) # type: ignore[assignment] + + @property + def image_num_patch(self): + h, w = self.image_default_input_size + return h // self.image_patch_size, w // self.image_patch_size + + +class ViTMLP(nn.Module): + """MLP used in Vision Transformer.""" + + def __init__( + self, + config: VisionBackboneConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.w1 = ColumnParallelLinear( + config.image_emb_dim, + config.image_mlp_dim, + bias=True, + quant_config=quant_config, + ) + # Activation function. + assert config.image_mlp_activations == "quick_gelu" + self.act = QuickGELU() + self.w2 = RowParallelLinear( + config.image_mlp_dim, + config.image_emb_dim, + bias=True, + quant_config=quant_config, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.w1(x) + x = self.act(x) + x, _ = self.w2(x) + return x + + +class MultiHeadDotProductAttention(nn.Module): + """Multi-head attention used in Vision Transformer.""" + + def __init__( + self, + config: VisionBackboneConfig, + use_bias: bool = True, + nlayers: int = 1, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.hidden_size = config.image_emb_dim + self.total_num_heads = config.image_num_heads + tp_size = get_tensor_model_parallel_world_size() + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % tp_size == 0 + + self.num_heads = self.total_num_heads // tp_size + self.head_dim = self.hidden_size // self.total_num_heads + + self.total_num_kv_heads = config.image_num_key_value_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.wq = ColumnParallelLinear( + nlayers * self.hidden_size, + self.total_num_heads * self.head_dim, + bias=use_bias, + quant_config=quant_config, + ) + self.wk = ColumnParallelLinear( + nlayers * self.hidden_size, + self.total_num_kv_heads * self.head_dim, + bias=use_bias, + quant_config=quant_config, + ) + self.wv = ColumnParallelLinear( + nlayers * self.hidden_size, + self.total_num_kv_heads * self.head_dim, + bias=use_bias, + quant_config=quant_config, + ) + self.wo = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=use_bias, + quant_config=quant_config, + ) + + # Detect attention implementation. + selected_backend: Optional[_Backend] = get_global_forced_attn_backend() + if selected_backend is None: + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = backend_name_to_enum(backend_by_env_var) + if selected_backend is None: + # For Volta and Turing GPUs, use xformers instead. + device_available = current_platform.get_device_capability()[0] >= 8 + if device_available: + from transformers.utils import is_flash_attn_2_available + if is_flash_attn_2_available(): + self._use_flash_attn = True + else: + log.warning( + "Current Molmo implementation has a bug with " + "`vllm-flash-attn` inside vision module, so we use " + "xformers backend instead. You can run `pip install " + "flash-attn to use flash-attention backend.") + self._use_flash_attn = False + else: + self._use_flash_attn = False + else: + if selected_backend == _Backend.FLASH_ATTN: + self._use_flash_attn = True + elif selected_backend == _Backend.XFORMERS: + self._use_flash_attn = False + else: + raise RuntimeError( + f"Molmo does not support {selected_backend} backend now.") + + def forward(self, + inputs_q: torch.Tensor, + inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor: + + if inputs_kv is not None: + inputs_k = inputs_kv + inputs_v = inputs_kv + else: + inputs_k = inputs_q + inputs_v = inputs_q + + xq, _ = self.wq(inputs_q) + xk, _ = self.wk(inputs_k) + xv, _ = self.wv(inputs_v) + q_shape = xq.size()[:-1] + (self.num_heads, self.head_dim) + kv_shape = xk.size()[:-1] + (self.num_kv_heads, self.head_dim) + xq = xq.view(*q_shape) + xk = xk.view(*kv_shape) + xv = xv.view(*kv_shape) + + if self._use_flash_attn: + from flash_attn import flash_attn_func + output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False) + else: + from xformers import ops as xops + output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0) + + output = rearrange(output, "b s h d -> b s (h d)").contiguous() + output, _ = self.wo(output) + + return output + + +class ResidualAttentionBlock(nn.Module): + """Residual attention block used in Vision Transformer.""" + + def __init__( + self, + config: VisionBackboneConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.attention = MultiHeadDotProductAttention( + config, quant_config=quant_config) + self.feed_forward = ViTMLP(config, quant_config) + self.attention_norm = nn.LayerNorm( + config.image_emb_dim, + eps=config.image_norm_eps, + ) + self.ffn_norm = nn.LayerNorm( + config.image_emb_dim, + eps=config.image_norm_eps, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.attention(self.attention_norm(x)) + x = x + self.feed_forward(self.ffn_norm(x)) + return x + + +class BlockCollection(nn.Module): + """Collection of residual attention blocks used in Vision Transformer.""" + + def __init__( + self, + config: VisionBackboneConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock(config, quant_config) + for _ in range(config.image_num_layers) + ]) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + hidden_states = [] + for r in self.resblocks: + x = r(x) + hidden_states.append(x) + return hidden_states + + +def _expand_token(token: torch.Tensor, batch_size: int) -> torch.Tensor: + return token.view(1, 1, -1).expand(batch_size, -1, -1) + + +class VisionTransformer(nn.Module): + """Vision Transformer used in Vision Backbone.""" + + def __init__( + self, + config: VisionBackboneConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + scale = config.image_emb_dim**-0.5 + self.patch_num = config.image_num_patch + self.class_embedding = nn.Parameter( + torch.randn(config.image_emb_dim) * scale) + self.num_prefix_tokens: int = NUM_PREFIX_TOKENS + self.positional_embedding = nn.Parameter( + torch.randn(config.image_num_pos, config.image_emb_dim) * scale) + image_patch_size = config.image_patch_size + self.patch_embedding = nn.Linear( + image_patch_size * image_patch_size * 3, + config.image_emb_dim, + bias=False, + ) + self.pre_ln = nn.LayerNorm(config.image_emb_dim, + eps=config.image_norm_eps) + self.transformer = BlockCollection(config, quant_config) + + def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor: + cls_emb = self.positional_embedding[0:1] + pos_emb = self.positional_embedding[1:] + + pos_emb = pos_emb.reshape( + (int(math.sqrt(pos_emb.shape[0])), + int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1])) + + (patch_num_0, patch_num_1) = patch_num + + if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1: + # from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2) + pos_emb = F.interpolate( + pos_emb, + size=(patch_num_0, patch_num_1), + mode="bicubic", + align_corners=False, + antialias=True, + ) + pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0) + + pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1]) + x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], + dim=1).to(x.dtype) + return x + + def forward(self, + x: torch.Tensor, + patch_num: int = None) -> List[torch.Tensor]: + """ + : param x: (batch_size, num_patch, n_pixels) + """ + if patch_num is None: + patch_num = self.patch_num + B, N, D = x.shape + + x = self.patch_embedding(x) + + # class embeddings and positional embeddings + x = torch.cat( + [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], + dim=1) + x = self.add_pos_emb(x, patch_num) + + x = self.pre_ln(x) + + hidden_states = self.transformer(x) + return hidden_states + + +class MolmoAttention(nn.Module): + """Molmo's LLM attention.""" + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % self.tp_size == 0 + + self.num_heads = self.total_num_heads // self.tp_size + self.total_num_kv_heads = config.num_key_value_heads \ + or self.total_num_heads + if self.total_num_kv_heads >= self.tp_size: + assert self.total_num_kv_heads % self.tp_size == 0 + else: + assert self.tp_size % self.total_num_kv_heads == 0 + + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + # Attention input projection. Projects x -> (q, k, v) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.qkv_bias, + quant_config=quant_config, + ) + + self.tp_rank: Optional[int] = None + self.k_norm: Optional[nn.Module] = None + self.q_norm: Optional[nn.Module] = None + if config.attention_layer_norm: + self.tp_rank = get_tensor_model_parallel_rank() + self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, + eps=config.layer_norm_eps) + self.q_norm = RMSNorm(config.hidden_size, + eps=config.layer_norm_eps) + + # Rotary embeddings. + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + # Attention output projection. + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def _apply_qk_norm(self, q: torch.Tensor, + k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.tp_size > 1: + q = tensor_model_parallel_all_gather(q.contiguous()) + k = tensor_model_parallel_all_gather(k.contiguous()) + q = self.q_norm.forward_native(q) + k = self.k_norm.forward_native(k) + if self.tp_size > 1: + splitter = partial(split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.q_norm is not None and self.k_norm is not None: + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class MolmoMLP(nn.Module): + """Molmo's LLM mlp.""" + + def __init__( + self, + config: PretrainedConfig, + input_dim: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size // 2 + + # Feed-forward input projection. + self.gate_up_proj = MergedColumnParallelLinear( + input_dim or self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) + + # Activation function. + self.act_fn = SiluAndMul() + + # Feed-forward output projection. + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class MolmoDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + # Attention block. + self.self_attn = MolmoAttention(config, cache_config, quant_config) + + # MLP block. + self.mlp = MolmoMLP(config, quant_config=quant_config) + + # LayerNorm + assert config.layer_norm_type == "rms" + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class MolmoDecoderNormAfterLayer(MolmoDecoderLayer): + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Self Attention + residual = hidden_states + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + hidden_states = self.input_layernorm(hidden_states) + hidden_states = hidden_states + residual + residual = hidden_states + + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = hidden_states + residual + residual = None + return hidden_states, residual + + +class MolmoVisionBackbone(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + vision_config: VisionBackboneConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.vit_layers = VIT_LAYERS + self.image_num_patch = vision_config.image_num_patch + self.llm_patches_per_crop = ( + (self.image_num_patch[0] + 1) // 2, + (self.image_num_patch[1] + 1) // 2, + ) + self.image_vit = VisionTransformer(vision_config, + quant_config=quant_config) + self.num_prefix_tokens = self.image_vit.num_prefix_tokens + assert self.num_prefix_tokens in { + 0, 1 + }, "Only 0 or 1 prefix tokens are supported" + self.image_pooling_2d = MultiHeadDotProductAttention( + vision_config, + nlayers=len(self.vit_layers), + quant_config=quant_config) + self.image_projector = MolmoMLP( + config, + input_dim=vision_config.image_emb_dim, + quant_config=quant_config, + ) + + image_dim = vision_config.image_emb_dim * len(self.vit_layers) + self.pad_embed = nn.Parameter(torch.zeros((2, image_dim))) + + @property + def dtype(self) -> torch.dtype: + return self.image_vit.patch_embedding.weight.dtype + + @property + def device(self) -> torch.device: + return self.image_vit.patch_embedding.weight.device + + def encode_image(self, images: torch.Tensor) -> torch.Tensor: + """ + : param images: (batch_size, num_crops, num_patch, n_pixels) + """ + B, T, N, D = images.shape + + mask = ~torch.all( + images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True) + + images = images.view(B * T, N, D) + image_features = self.image_vit(images) + + if self.vit_layers is not None: + features = [] + for layer in self.vit_layers: + features.append(image_features[layer]) + image_features = torch.cat(features, dim=-1) + else: + image_features = image_features[-1] + + if self.num_prefix_tokens > 0: + image_features = image_features[:, 1:] + + image_features = image_features * mask + image_features = image_features.view(B, T, N, -1) + + return image_features + + def forward( + self, images: torch.Tensor, image_masks: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + + # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) # noqa: E501 + batch_size, num_image = images.shape[:2] + images = images.to(device=self.device, dtype=self.dtype) + image_features = self.encode_image(images) + + og_dtype = image_features.dtype + assert image_masks is not None + pad_embed = self.pad_embed[:, None, None, None, :] + all_pad = image_masks == 0 + partial_pad = torch.logical_and( + image_masks < 1, + torch.logical_not(all_pad)).to(dtype=torch.float32) + all_pad = all_pad.to(dtype=torch.float32) + image_features = image_features + pad_embed[0] * torch.unsqueeze( + all_pad, -1) + image_features = image_features + pad_embed[1] * torch.unsqueeze( + partial_pad, -1) + + image_features = image_features.to(og_dtype) + + image_features = image_features.reshape( + (batch_size, num_image) + self.image_num_patch + (-1, ), ) + + if self.image_num_patch[0] % 2 == 1: + # Pad so we can still pool 2x2 patches + image_features = F.pad( + image_features, + (0, 0, 0, 1, 0, 1, 0, 0, 0, 0), + ) + + # image pooling + image_features = rearrange( + image_features, + 'b n (h dh) (w dw) c -> (b n h w) (dh dw) c', + dh=2, + dw=2, + ) + + query = image_features.mean(-2, keepdim=True) + image_features = self.image_pooling_2d(query, image_features) + + h, w = self.llm_patches_per_crop + image_features = image_features.view(batch_size, num_image, h * w, -1) + + image_features = self.image_projector(image_features) + + # image_features: (batch_size, num_image, num_patch, d_model) + return image_features + + +class MolmoModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + self.embedding_size = config.embedding_size or config.vocab_size + self.embedding_size += ADDITIONAL_VOCAB_SIZE + self.embed_tokens = VocabParallelEmbedding( + self.embedding_size, + config.hidden_size, + quant_config=quant_config, + ) + + decoder_layer = MolmoDecoderNormAfterLayer if config.norm_after \ + else MolmoDecoderLayer + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: decoder_layer(config, cache_config, quant_config), + prefix=f"{prefix}.layers", + ) + + assert config.layer_norm_type == "rms" + self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + # Apply blocks one-by-one. + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + if residual is not None: + hidden_states, _ = self.norm(hidden_states, residual) + else: + hidden_states = self.norm(hidden_states) + return hidden_states + + +cached_get_processor = lru_cache(get_processor) + + +def get_num_patches(num_tiles: int, crop_patches: int, left_margin: int, + right_margin: int, pooling_size: int) -> int: + crop_window_patches = crop_patches - (left_margin + right_margin) + if num_tiles > 1: + left_crop_window_patches = (crop_window_patches + left_margin + + pooling_size - + 1) // pooling_size * pooling_size + middle_crop_window_patches = (crop_window_patches + pooling_size - + 1) // pooling_size * pooling_size + right_crop_window_patches = (crop_window_patches + right_margin + + pooling_size - + 1) // pooling_size * pooling_size + return left_crop_window_patches + ( + num_tiles - + 2) * middle_crop_window_patches + right_crop_window_patches + else: + single_crop_window_patches = (crop_patches + pooling_size - + 1) // pooling_size * pooling_size + return single_crop_window_patches + + +def get_tokens(tiling_h: int, tiling_w: int, crop_patches: int, + left_margin: int, right_margin: int, pooling_size: int) -> int: + h = get_num_patches(tiling_h, crop_patches, left_margin, right_margin, + pooling_size) + w = get_num_patches(tiling_w, crop_patches, left_margin, right_margin, + pooling_size) + per_row = w // pooling_size + 1 + joint = per_row * (h // pooling_size) + 2 + image_token_length = (crop_patches + pooling_size - 1) // pooling_size + resize = (image_token_length + 1) * image_token_length + 2 + return resize + joint + + +def get_max_tokens(max_crops: int, crop_patches: int, left_margin: int, + right_margin: int, pooling_size: int) -> int: + tilings = [] + for i in range(1, max_crops + 1): + for j in range(1, max_crops + 1): + if i * j <= max_crops: + tilings.append((i, j)) + tokens = [ + get_tokens(tilings[i][0], tilings[i][1], crop_patches, left_margin, + right_margin, pooling_size) for i in range(len(tilings)) + ] + return max(tokens) + + +def get_max_molmo_image_tokens(ctx: InputContext) -> int: + processor = cached_get_processor(ctx.model_config.model, + trust_remote_code=True, + revision=ctx.model_config.code_revision) + image_processor = processor.image_processor + max_llm_image_tokens = get_max_tokens( + image_processor.max_crops, + image_processor.base_image_input_size[0] // + image_processor.image_patch_size, + image_processor.overlap_margins[0], + image_processor.overlap_margins[1], + 2, + ) + return max_llm_image_tokens + + +# NOTE: preprocessing for the image data has been included in the +# 'input_processor_for_molmo' function +def image_input_mapper_for_molmo( + ctx: InputContext, + data: object, +): + return MultiModalInputs(data) + + +def dummy_data_for_molmo(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + processor = cached_get_processor(ctx.model_config.model, + trust_remote_code=True, + revision=ctx.model_config.code_revision) + image_processor = processor.image_processor + + base_image_input_d = image_processor.image_patch_size + left_margin, right_margin = image_processor.overlap_margins + max_crops = image_processor.max_crops + + # Assume: prompt_token_ids always starts with bos_token_id followed image tokens # noqa: E501 + max_llm_image_tokens = get_max_molmo_image_tokens(ctx) + if seq_len - max_llm_image_tokens - 1 < 0: + raise RuntimeError( + f"Molmo cannot process {max_crops} crops in a prompt, " + "please increase max_model_len or reduce number of crops") + + # The vertical image has the maximum number of image tokens due to column tokens. # noqa: E501 + tiling = (max_crops, 1) + total_margin_pixels = base_image_input_d * (right_margin + left_margin) + crop_patches = image_processor.base_image_input_size[ + 0] // base_image_input_d + crop_window_patches = crop_patches - (right_margin + left_margin) + crop_window_size = crop_window_patches * base_image_input_d + + h = crop_window_size * tiling[0] + total_margin_pixels + w = crop_window_size * tiling[1] + total_margin_pixels + + dummy_image = Image.new("RGB", (w, h), color="red") + + out = processor.process("dummy prompt", dummy_image) + + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + out["input_ids"][:1 + max_llm_image_tokens]) + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - max_llm_image_tokens - 1) + dummy_seqdata = SequenceData(token_ids) + dummy_imgdata = { + "images": out["images"], + "image_input_idx": out["image_input_idx"], + } + if "image_masks" in out: + dummy_imgdata["image_masks"] = out["image_masks"] + dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long) + return dummy_seqdata, {"image": dummy_imgdata} + + +def pad_images( + max_total_crops: int, + images: torch.Tensor, + image_input_idx: torch.Tensor, + image_masks: Optional[torch.Tensor] = None, +): + n = max_total_crops - images.shape[0] + images = F.pad(images, (0, 0, 0, 0, 0, n), value=-1) + image_input_idx = F.pad(image_input_idx, (0, 0, 0, n), value=-1) + if image_masks is not None: + image_masks = F.pad(image_masks, (0, 0, 0, n), value=-1) + return images, image_input_idx, image_masks + + +def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs): + prompt = llm_inputs["prompt"] + multi_modal_data = llm_inputs.get("multi_modal_data") + image = multi_modal_data.get("image") + processor = cached_get_processor(ctx.model_config.model, + trust_remote_code=True, + revision=ctx.model_config.code_revision) + + # NOTE: message formatting for raw text prompt is only applied for + # offline inference; for online inference, the prompt is always in + # instruction format and tokenized. + if prompt is not None and re.match(r"^User:[\s\S]*?(Assistant:)*$", + prompt): + out = processor.process(prompt, image, message_format="none") + elif prompt is not None: + out = processor.process(prompt, image) + else: + out = processor.process(None, + image, + tokens=llm_inputs["prompt_token_ids"]) + + image_processor = processor.image_processor + max_total_crops = 1 + image_processor.max_crops + if image is not None: + images, image_input_idx, image_masks = pad_images( + max_total_crops, + out["images"], + out["image_input_idx"], + out.get("image_masks"), + ) + else: + base_image_input_size = image_processor.base_image_input_size + image_patch_size = image_processor.image_patch_size + image_num_patch = ( + base_image_input_size[0] // image_patch_size, + base_image_input_size[1] // image_patch_size, + ) + n_pixels = image_patch_size * image_patch_size * 3 + n_patches = image_num_patch[0] * image_num_patch[1] + + image_length_w = image_processor.image_token_length_w + image_length_h = image_processor.image_token_length_h + tokens_per_image = image_length_w * image_length_h + images = torch.full( + (max_total_crops, n_patches, n_pixels), + -1, + dtype=torch.float32, + ) + image_input_idx = torch.full( + (max_total_crops, tokens_per_image), + -1, + dtype=torch.int32, + ) + if image_processor.image_padding_mask: + image_masks = torch.full( + (max_total_crops, n_patches), + -1, + dtype=torch.float32, + ) + + image_data = dict( + images=images, + image_input_idx=image_input_idx, + ) + if image_masks is not None: + image_data["image_masks"] = image_masks + + image_data["seq_len"] = torch.tensor(len(out["input_ids"]), + dtype=torch.long) + + multi_modal_data = dict(image=image_data) + + return LLMInputs( + prompt_token_ids=out["input_ids"], + prompt=llm_inputs["prompt"], + multi_modal_data=multi_modal_data, + ) + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(image_input_mapper_for_molmo) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo) +@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo) +class MolmoForCausalLM(nn.Module, SupportsMultiModal): + + def __init__( + self, + config: PretrainedConfig, + multimodal_config: Optional[MultiModalConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[Mapping[str, Any]] = None, + ) -> None: + super().__init__() + + self.config = config + self.multimodal_config = multimodal_config + + vision_config = VisionBackboneConfig() + self.vision_backbone = MolmoVisionBackbone(config, vision_config, + quant_config) + self.model = MolmoModel(config, cache_config, quant_config) + + if self.config.weight_tying: + self.lm_head = self.model.transformer.wte + else: + self.lm_head = ParallelLMHead( + config.embedding_size or config.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + + self.logits_processor = LogitsProcessor(config.embedding_size + or config.vocab_size) + self.sampler = Sampler() + + def _parse_and_validate_image_input( + self, + **kwargs: object, + ) -> Optional[MolmoImageInputs]: + images = kwargs.pop("images", None) + image_masks = kwargs.pop("image_masks", None) + if images is None: + return None + + image_input_idx = kwargs.pop("image_input_idx", None) + seq_len = kwargs.pop("seq_len", None) + if image_input_idx is None: + raise ValueError("image_input_idx is required for Molmo model.") + if seq_len is None: + raise ValueError("seq_len is required for Molmo model.") + if not isinstance(seq_len, torch.Tensor): + seq_len = torch.tensor(seq_len) + + return MolmoImageInputs( + images=images, + image_input_idx=image_input_idx, + seq_len=seq_len, + image_masks=image_masks, + ) + + def _process_image_input( + self, + image_input: MolmoImageInputs, + ) -> torch.Tensor: + + image_features = self.vision_backbone( + images=image_input["images"], + image_masks=image_input["image_masks"], + ) + + return image_features + + def _merge_multimodal_embeddings( + self, + inputs_embeds: torch.Tensor, + image_features: torch.Tensor, + image_input_idx: torch.Tensor, + seq_len: Union[torch.Tensor, List[torch.Tensor]], + ) -> torch.Tensor: + batch_size, num_image, num_patch = image_features.shape[:3] + assert image_input_idx.shape == (batch_size, num_image, num_patch) + + image_features = image_features.to(inputs_embeds.device) + seq_len = seq_len.to(inputs_embeds.device) + + # insert the image feature into the embedding. + image_features = image_features.view(batch_size, num_image * num_patch, + -1) + image_input_idx = image_input_idx.view(batch_size, + num_image * num_patch) + + valid = image_input_idx >= 0 + image_features = image_features * valid[:, :, None].to( + image_features.dtype) + image_features = image_features.view( + batch_size * num_image * num_patch, -1).contiguous() + + image_input_idx = image_input_idx * valid.to(image_input_idx.dtype) + offset = torch.cat( + [seq_len.new_zeros( + (1)), seq_len.cumsum(dim=0)[:-1]], dim=0)[:, None] + image_input_idx = image_input_idx + offset.to(image_input_idx.dtype) + image_input_idx = image_input_idx.flatten()[:, None] + mat = image_input_idx == torch.arange( + seq_len.sum().item(), device=inputs_embeds.device)[None, :] + mat = mat.to(image_features.dtype) + + inputs_embeds = inputs_embeds + torch.einsum('nd,nm->md', + image_features, mat) + + return inputs_embeds + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.LongTensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + **kwargs: object, + ) -> SamplerOutput: + + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + inputs_embeds = self.model.embed_tokens(input_ids) + image_features = self._process_image_input(image_input) + + inputs_embeds = self._merge_multimodal_embeddings( + inputs_embeds, + image_features, + image_input["image_input_idx"], + image_input["seq_len"], + ) + + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + params_mapping = [ + ("model.transformer.ln_f.weight", "model.norm.weight"), + ("attn_out", "self_attn.o_proj"), + ("att_proj", "self_attn.qkv_proj"), + ("q_norm", "self_attn.q_norm"), + ("k_norm", "self_attn.k_norm"), + ("attn_norm", "input_layernorm"), + ("ff_norm", "post_attention_layernorm"), + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + + embedding_weight = dict() + projector_weight = dict() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + + if "wte.embedding" in name: + embedding_weight["embedding"] = loaded_weight + continue + + if "wte.new_embedding" in name: + embedding_weight["new_embedding"] = loaded_weight + continue + + if "vision_backbone" in name: + if name.startswith("model"): + name = name[len("model."):] + if 'image_projector' in name: + if 'w1' in name: + projector_weight['gate_proj'] = loaded_weight + elif 'w3' in name: + projector_weight['up_proj'] = loaded_weight + elif 'w2' in name: + projector_weight['down_proj'] = loaded_weight + else: + raise ValueError( + f"Unexpected projector weight: {name}") + continue + else: + if "transformer.blocks" in name: + name = name.replace("transformer.blocks", "layers") + + if "ff_proj" in name: + name = name.replace("ff_proj", "mlp.gate_up_proj") + assert 'weight' in name + up_weight, gate_weight = loaded_weight.chunk(2, dim=0) + loaded_weight = torch.cat([gate_weight, up_weight], dim=0) + + elif "ff_out" in name: + if "layers" in name: + name = name.replace("ff_out", "mlp.down_proj") + else: + # lm head + name = name.replace("model.transformer.ff_out", + "lm_head") + + else: + for (param_name, weight_name) in params_mapping: + if param_name in name: + name = name.replace(param_name, weight_name) + break + + try: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + except KeyError: + raise ValueError(f"Unexpected weight: {name}") from None + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + gate_up_proj_weight = torch.cat( + [projector_weight["gate_proj"], projector_weight["up_proj"]], + dim=0) + name = "vision_backbone.image_projector.gate_up_proj.weight" + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, gate_up_proj_weight) + + down_proj_weight = projector_weight["down_proj"] + name = "vision_backbone.image_projector.down_proj.weight" + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, down_proj_weight) + + embedding_weight = torch.cat( + [embedding_weight["embedding"], embedding_weight["new_embedding"]], + dim=0) + name = "model.embed_tokens.weight" + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, embedding_weight) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 24fd5152..4a39b3fb 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1167,8 +1167,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, continue param = params_dict[name] except KeyError: - print(params_dict.keys()) - raise + raise ValueError(f"Unexpected weight: {name}") from None weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 8caaab99..b06d3d61 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -104,6 +104,7 @@ _MULTIMODAL_MODELS = { "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501 "MiniCPMV": ("minicpmv", "MiniCPMV"), + "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),