# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import numpy as np
import torch
import torch.nn.functional as F
import transformers.models.mllama.configuration_mllama as config_mllama
from PIL.Image import Image
from torch import nn
from transformers import BatchFeature, MllamaConfig
from transformers.modeling_outputs import (BaseModelOutput,
CausalLMOutputWithPast)
from transformers.models.mllama.image_processing_mllama import (
get_optimal_tiled_canvas)
from transformers.models.mllama.processing_mllama import (
MllamaProcessor, get_cross_attention_token_mask)
import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.selector import _Backend
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVCrossParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalKwargs)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataDict, MultiModalDataItems)
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from .clip import CLIPMLP
from .interfaces import SupportsMultiModal, SupportsV0Only
from .llama import LlamaDecoderLayer, LlamaMLP
from .utils import maybe_prefix
logger = init_logger(__name__)
class MllamaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: """
"""(batch_size, max_num_image, max_num_chunk, num_channel, height, width)"""
aspect_ratio_ids: torch.Tensor
"""Shape: `(batch_size, max_num_image)`"""
aspect_ratio_mask: torch.Tensor
"""Shape: `(batch_size, max_num_image, max_num_tiles)`"""
# TODO: support MllamaImageEmbeddingInputs
def calc_token_per_chunk(image_size: int) -> int:
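    """Return the number of vision tokens produced per image tile.

    Each tile is split into 14x14-pixel patches plus one class token, so
    e.g. a 560-pixel tile yields (560 // 14) ** 2 + 1 = 1601 tokens.
    """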
    assert image_size % 14 == 0, "image size must be a multiple of 14"
token_per_chunk = (image_size // 14)**2 + 1
return token_per_chunk
class MllamaProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> MllamaConfig:
return self.ctx.get_hf_config(MllamaConfig)
def get_hf_processor(self, **kwargs: object) -> MllamaProcessor:
return self.ctx.get_hf_processor(MllamaProcessor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_token_per_chunk_from_config(self) -> int:
image_size = self.get_hf_config().vision_config.image_size
return calc_token_per_chunk(image_size)
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
vision_config = self.get_hf_config().vision_config
token_per_chunk = self.get_token_per_chunk_from_config()
mm_max_tokens = vision_config.max_num_tiles * token_per_chunk
return {"image": mm_max_tokens}
def get_num_tiles_per_image(self, image_height: int,
image_width: int) -> int:
vision_config = self.get_hf_config().vision_config
max_num_tiles = vision_config.max_num_tiles
image_size = vision_config.image_size
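        # get_optimal_tiled_canvas picks the canvas (in pixels) that best
        # fits the image using at most max_num_tiles square tiles of
        # image_size pixels; dividing by the tile size gives the tile grid,
        # e.g. a 1120x560 canvas with 560-px tiles is a 2x1 grid = 2 tiles.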
tiled_height, tiled_width = get_optimal_tiled_canvas(
image_height,
image_width,
max_num_tiles,
tile_size=image_size,
)
num_tiles_height = tiled_height // image_size
num_tiles_width = tiled_width // image_size
return num_tiles_height * num_tiles_width
def get_image_size_with_most_features(self) -> ImageSize:
vision_config = self.get_hf_config().vision_config
image_size = vision_config.image_size
max_num_tiles = vision_config.max_num_tiles
        # Results in the max possible feature size (max_num_tiles x 1 grid)
return ImageSize(height=max_num_tiles * image_size, width=image_size)
class MllamaDummyInputsBuilder(BaseDummyInputsBuilder[MllamaProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
target_width, target_height = \
self.info.get_image_size_with_most_features()
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
hf_processor = self.info.get_hf_processor()
image_token: str = hf_processor.image_token
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
):
def apply(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
return_mm_hashes: bool = False,
) -> MultiModalEncDecInputs:
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
return_mm_hashes)
image_token_id = self.info.get_hf_config().image_token_index
# Check that the number of image tokens in the decoder prompt matches
# the number of images provided in mm_data
num_image_tokens = mm_inputs['prompt_token_ids'].count(image_token_id)
image_data = mm_data.get("image", [])
num_images = 1 if isinstance(image_data, Image) else len(image_data)
if num_image_tokens != num_images:
raise ValueError(
f"The number of image tokens ({num_image_tokens}) must be"
f" the same as the number of images ({num_images})")
# Given prompt: <IMG0> P0 P1 <IMG1> <IMG2> P3 P4 D5 D6...., (P-prefill, D-decode) # noqa: E501
# P0 & P1 do cross attention with placeholder of <IMG0>
# P3 P4 D5 D6 do cross attention with placeholder of <IMG1> and <IMG2>
# Example input to encoder and decoder:
# {
# 'encoder': {
# 'type': 'token',
# 'prompt_token_ids': [128256, 128256, ..., 128256],
# 'prompt': '<|image|><|image|>...<|image|>',
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
# },
# 'decoder': {
# 'type': 'token',
# 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
# 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
# },
# }
if mm_data:
            # Since only the last group of consecutive images is attended to
            # by the decoded tokens, we only need the number of encoder
            # tokens for those images.
token_per_chunk = self.info.get_token_per_chunk_from_config()
num_decode_images = self._get_num_image_in_last_group(
mm_inputs["prompt_token_ids"])
num_encode_images = num_images - num_decode_images
# Set encoder prompt length based on the number of tiles.
# This tells the block manager to allocate correct number
# of slots for encoder tokens.
num_tiles = mm_inputs["mm_kwargs"]["num_tiles"]
decode_tiles = num_tiles[num_encode_images:num_images].sum().item()
num_tokens = decode_tiles * token_per_chunk
mm_inputs["encoder_prompt_token_ids"] = [image_token_id
] * num_tokens
mm_inputs["encoder_prompt"] = "<|image|>" * num_tokens
return mm_inputs
def _get_num_image_in_last_group(self, prompt_token_ids: List[int]) -> int:
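        """Count the last consecutive run of image tokens in the prompt.

        Only this last group of images is cross-attended by the generated
        tokens, e.g. [..., <image>, <text>, <image>, <image>, <text>]
        returns 2.
        """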
num_images = 0
for token_id in prompt_token_ids[::-1]:
if token_id == self.info.get_hf_config().image_token_index:
num_images += 1
elif num_images > 0:
break
return num_images
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
tokenizer = self.info.get_tokenizer()
if mm_data:
num_tiles = [
self.info.get_num_tiles_per_image(img.height, img.width)
for img in mm_data["images"]
]
processed_outputs = super()._call_hf_processor(
prompt, mm_data, mm_kwargs)
processed_outputs["num_tiles"] = torch.tensor(num_tiles)
for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"):
processed_outputs[k] = processed_outputs[k].squeeze(0)
processed_token_ids = processed_outputs.pop("input_ids")
start_idx, end_idx = 0, processed_token_ids.size(1)
processed_prompt_text = tokenizer.decode(processed_token_ids[0])
hf_processor = self.info.get_hf_processor()
bos_token = hf_processor.bos_token
            # Remove the bos_token from the start of the prompt,
            # since an image token is expected to come first.
if processed_prompt_text.startswith(bos_token):
start_idx += 1
            # Remove the bos_token from the end of the prompt,
            # which happens when the text portion is empty.
if processed_prompt_text.endswith(bos_token):
end_idx -= 1
processed_outputs[
"input_ids"] = processed_token_ids[:, start_idx:end_idx]
else:
processed_outputs = tokenizer(prompt,
add_special_tokens=False,
return_tensors="pt")
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
aspect_ratio_ids=MultiModalFieldConfig.batched("image"),
aspect_ratio_mask=MultiModalFieldConfig.batched("image"),
num_tiles=MultiModalFieldConfig.batched("image"),
)
def create_encoder_prompt(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
) -> Union[str, list[int]]:
data = mm_data.get("image", [])
num_images = 1 if isinstance(data, Image) else len(data)
image_token_id = self.info.get_hf_config().image_token_index
return [image_token_id] * num_images
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
token_per_chunk = self.info.get_token_per_chunk_from_config()
image_token_id = self.info.get_hf_config().image_token_index
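        # Each <|image|> placeholder in the decoder prompt is expanded to
        # num_tiles(image) * token_per_chunk image tokens, e.g. 4 tiles *
        # 1601 tokens/tile = 6404 tokens for an image tiled at 560 px.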
def get_replacement_mllama(item_idx):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
num_tile = self.info.get_num_tiles_per_image(
image_height=image_size.height,
image_width=image_size.width,
)
num_tokens = num_tile * token_per_chunk
return [image_token_id] * num_tokens
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement_mllama,
)
]
def _prepare_aspect_ratio_attention_mask(
aspect_ratio_mask: torch.Tensor,
num_patches: int,
target_length: int,
dtype: torch.dtype,
) -> torch.Tensor:
# Expand aspect ratio mask to target_length
batch_size, max_num_tiles = aspect_ratio_mask.shape
attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1,
1).to(dtype)
attention_mask = attention_mask.repeat(1, 1, target_length, 1)
# Mask padding patches
pad_patches = target_length - num_patches
attention_mask[:, :, -pad_patches:] = 0
# Invert the mask (0 -> 1, 1 -> 0)
attention_mask = 1 - attention_mask
# Reshape to 2D and create 4D attention mask
# (batch_size, 1, max_num_tiles*target_length, max_num_tiles*target_length)
attention_mask = attention_mask.reshape(batch_size,
max_num_tiles * target_length, 1)
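    # The outer product of the inverted mask (1 = padding) with itself is 1
    # only where both the query and the key position are padding; scaling by
    # the dtype minimum turns those entries into a large negative additive
    # bias for scaled_dot_product_attention.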
attention_mask = attention_mask @ attention_mask.transpose(
-1, -2) * torch.finfo(dtype).min
attention_mask = attention_mask.unsqueeze(1)
return attention_mask
class ColumnParallelConv2dPatch(torch.nn.Module):
"""Conv2D Patching layer with model parallelism.
Column parallel over unfolded input.
Arguments:
in_channels: Input channels.
out_channels: Output channels.
kernel_size: Size of convolution kernel.
stride (default 1): Stride for convolution.
bias (default False): Use bias in Conv2d.
Input: (bsz, in_channels, width, height)
Output: (bsz, num_tokens, out_channels)
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]],
bias: bool = False,
) -> None:
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
self._linear = ColumnParallelLinear(
in_channels * kernel_size[0] * kernel_size[1],
out_channels,
bias=bias,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
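        # Unfold extracts kernel-sized patches (non-overlapping when
        # stride == kernel_size, as in ViT-style patch embedding) and
        # flattens each patch, so the column-parallel linear below is
        # equivalent to a Conv2d with its weight flattened to
        # (out_channels, in_channels * kh * kw).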
x = self._unfold(x)
x = x.permute(0, 2, 1)
x, _ = self._linear(x)
return x
class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
def __init__(self,
config: config_mllama.MllamaVisionConfig,
is_gated: bool = True):
super().__init__()
self.max_num_tiles = config.max_num_tiles
self.hidden_size = config.hidden_size
self.max_aspect_ratio_id = config.max_aspect_ratio_id
self.is_gated = is_gated
self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1,
self.max_num_tiles * self.hidden_size)
if is_gated:
self.gate = nn.Parameter(torch.zeros(1))
def forward(self, hidden_state: torch.Tensor,
aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
embeddings = self.embedding(aspect_ratio_ids)
embeddings = embeddings.reshape(-1, self.max_num_tiles, 1,
self.hidden_size)
if self.is_gated:
embeddings = embeddings * self.gate.tanh()
hidden_state = hidden_state + embeddings
return hidden_state
class MllamaPrecomputedPositionEmbedding(nn.Module):
def __init__(self, config: config_mllama.MllamaVisionConfig):
super().__init__()
self.max_num_tiles = config.max_num_tiles
self.max_aspect_ratio_id = config.max_aspect_ratio_id
self.num_patches = (config.image_size // config.patch_size)**2 + 1
self.hidden_size = config.hidden_size
self.scale = config.hidden_size**-0.5
self.gate = nn.Parameter(torch.zeros(1))
# position embedding
position_embedding = torch.randn(self.num_patches, self.hidden_size)
self.embedding = nn.Parameter(self.scale * position_embedding)
# tile position embedding
self.tile_embedding = nn.Embedding(
self.max_aspect_ratio_id + 1,
self.max_num_tiles * self.num_patches * self.hidden_size)
def forward(self, hidden_state: torch.Tensor,
aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
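        # A single learned gate interpolates the two embeddings:
        # (1 - tanh(gate)) weights the shared per-patch embedding and
        # tanh(gate) weights the aspect-ratio-specific tile embedding.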
# position embeddings
gated_position_embedding = (1 - self.gate.tanh()) * self.embedding
hidden_state = hidden_state + gated_position_embedding.view(
1, 1, self.num_patches, self.hidden_size)
# precomputed tile position embeddings
tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
batch_size = hidden_state.shape[0]
tile_position_embedding = tile_position_embedding.reshape(
batch_size, self.max_num_tiles, self.num_patches, self.hidden_size)
gated_tile_position_embedding = self.gate.tanh(
) * tile_position_embedding
hidden_state = hidden_state + gated_tile_position_embedding
return hidden_state
# TODO: support other attention backends in the vision model
class MllamaVisionSdpaAttention(nn.Module):
def __init__(self,
config: config_mllama.MllamaVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
tensor_parallel_size = get_tp_group().world_size
self.embed_dim = config.hidden_size
self.num_heads = config.attention_heads
self.head_dim = config.hidden_size // config.attention_heads
self.num_local_heads = self.num_heads // tensor_parallel_size
self.q_size = self.num_local_heads * self.head_dim
self.kv_size = self.num_local_heads * self.head_dim
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
self.embed_dim,
bias=False,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
def forward(
self,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_state)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = q.view(q.shape[0], q.shape[1], self.num_local_heads,
self.head_dim).transpose(1, 2)
k = k.view(k.shape[0], k.shape[1], self.num_local_heads,
self.head_dim).transpose(1, 2)
v = v.view(v.shape[0], v.shape[1], self.num_local_heads,
self.head_dim).transpose(1, 2)
# TODO: remove padding in image encoder
attn_output = F.scaled_dot_product_attention(q,
k,
v,
attn_mask=attention_mask,
dropout_p=0.0)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(attn_output.shape[0],
attn_output.shape[1], -1)
output, _ = self.o_proj(attn_output)
return output
class MllamaVisionEncoderLayer(nn.Module):
def __init__(
self,
config: config_mllama.MllamaVisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
is_gated: bool = False,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.num_attention_heads = config.attention_heads
self.is_gated = is_gated
self.intermediate_size = config.intermediate_size
self.self_attn = MllamaVisionSdpaAttention(
config, quant_config=quant_config, prefix=f"{prefix}.self_attn")
self.mlp = CLIPMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.input_layernorm = nn.LayerNorm(self.hidden_size,
eps=config.norm_eps)
self.post_attention_layernorm = nn.LayerNorm(self.hidden_size,
eps=config.norm_eps)
        # There used to be an if/else here; the non-gated branch was dead
        # code.
if is_gated:
self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4)
self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4)
def forward(
self,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
):
# Self Attention
residual = hidden_state
hidden_state = self.input_layernorm(hidden_state)
hidden_state = self.self_attn(hidden_state,
attention_mask=attention_mask)
gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
hidden_state = residual + gate_attn * hidden_state
# Feed forward
residual = hidden_state
hidden_state = self.post_attention_layernorm(hidden_state)
hidden_state = self.mlp(hidden_state)
gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
hidden_state = residual + gate_ffn * hidden_state
return hidden_state
class MllamaVisionEncoder(nn.Module):
def __init__(
self,
config: config_mllama.MllamaVisionConfig,
quant_config: Optional[QuantizationConfig],
num_layers: int = 32,
is_gated: bool = False,
        output_hidden_states: Optional[List[int]] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.layers = nn.ModuleList([
MllamaVisionEncoderLayer(config,
quant_config=quant_config,
is_gated=is_gated,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_layers)
])
self.output_hidden_states = output_hidden_states or []
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutput]:
encoder_states = ()
for i, encoder_layer in enumerate(self.layers):
if i in self.output_hidden_states:
encoder_states = encoder_states + (hidden_states, )
hidden_states = encoder_layer(
hidden_states,
attention_mask,
)
if len(self.layers) - 1 in self.output_hidden_states:
encoder_states = encoder_states + (hidden_states, )
return hidden_states, encoder_states
class MllamaVisionModel(nn.Module):
def __init__(
self,
config: config_mllama.MllamaVisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
super().__init__()
self.image_size = config.image_size
self.patch_size = config.patch_size
self.max_num_tiles = config.max_num_tiles
self.hidden_size = config.hidden_size
self.in_channels = config.num_channels
self.intermediate_layers_indices = config.intermediate_layers_indices
self.num_patches = (self.image_size // self.patch_size)**2 + 1
self.scale = config.hidden_size**-0.5
self.patch_embedding = ColumnParallelConv2dPatch(
in_channels=config.num_channels,
out_channels=self.hidden_size,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.class_embedding = nn.Parameter(self.scale *
torch.randn(self.hidden_size))
self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
config)
self.pre_tile_positional_embedding = \
MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)
self.post_tile_positional_embedding = \
MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)
# layer norms
self.layernorm_pre = nn.LayerNorm(self.hidden_size)
self.layernorm_post = nn.LayerNorm(self.hidden_size)
# encoders
self.transformer = MllamaVisionEncoder(
config,
quant_config,
config.num_hidden_layers,
is_gated=False,
output_hidden_states=config.intermediate_layers_indices,
prefix=f"{prefix}.transformer",
)
self.global_transformer = MllamaVisionEncoder(
config,
quant_config,
config.num_global_layers,
is_gated=True,
prefix=f"{prefix}.global_transformer",
)
def apply_class_embedding(self,
hidden_state: torch.Tensor) -> torch.Tensor:
batch_size, _, hidden_size = hidden_state.shape
class_embedding = self.class_embedding.expand(batch_size, 1,
hidden_size)
hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
return hidden_state
def forward(self, pixel_values: torch.Tensor,
aspect_ratio_ids: torch.Tensor,
aspect_ratio_mask: torch.Tensor) -> torch.Tensor:
batch_size, num_concurrent_media, num_tiles, num_channels, \
height, width = pixel_values.shape
pixel_values = pixel_values.reshape(
batch_size * num_concurrent_media * num_tiles, num_channels,
height, width)
aspect_ratio_ids = aspect_ratio_ids.reshape(
batch_size * num_concurrent_media, -1)
# patch embedding
patch_embeds = self.patch_embedding(
pixel_values.to(self.layernorm_pre.weight.dtype))
hidden_state = patch_embeds
hidden_state = ps.get_tp_group().all_gather(hidden_state)
# tile embeddings
_, num_patches, dim = hidden_state.shape
hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
num_tiles, -1, dim)
hidden_state = self.pre_tile_positional_embedding(
hidden_state, aspect_ratio_ids)
# apply cls token
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media * num_tiles, num_patches, dim)
hidden_state = self.apply_class_embedding(hidden_state)
num_patches += 1
# apply position embeddings
hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
num_tiles, num_patches, dim)
hidden_state = self.gated_positional_embedding(hidden_state,
aspect_ratio_ids)
# apply encoder
hidden_state = self.layernorm_pre(hidden_state)
# Compute the number of tokens to pad
num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
# Compute padding tuple for pad function
padding = (
0, 0, 0, num_padding_patches
) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
# Pad the tensor
hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
slice_index = -num_padding_patches if num_padding_patches > 0 else None
attention_mask = aspect_ratio_mask.reshape(
batch_size * num_concurrent_media, -1)
attention_mask = _prepare_aspect_ratio_attention_mask(
aspect_ratio_mask=attention_mask,
num_patches=self.num_patches,
target_length=hidden_state.shape[2],
dtype=self.layernorm_pre.weight.dtype,
)
hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1,
dim)
output = self.transformer(
hidden_state,
attention_mask=attention_mask,
)
hidden_state, intermediate_hidden_states = output[0], output[1]
intermediate_hidden_states = torch.stack(intermediate_hidden_states,
dim=-1)
# apply global encoder
hidden_state = self.layernorm_post(hidden_state)
hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
num_tiles,
num_patches + num_padding_patches,
dim)
hidden_state = self.post_tile_positional_embedding(
hidden_state, aspect_ratio_ids)
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media,
num_tiles * (num_patches + num_padding_patches), dim)
hidden_state = self.global_transformer(
hidden_state, attention_mask=attention_mask)[0]
hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
num_tiles,
num_patches + num_padding_patches,
dim)
hidden_state = hidden_state[:, :, :slice_index]
# adding intermediate layer outputs
hidden_state = hidden_state.reshape(batch_size, num_concurrent_media,
num_tiles, num_patches, dim)
intermediate_hidden_states = intermediate_hidden_states.reshape(
batch_size * num_concurrent_media, num_tiles,
num_patches + num_padding_patches, -1)
intermediate_hidden_states = intermediate_hidden_states[:, :, :
slice_index]
intermediate_hidden_states = intermediate_hidden_states.reshape(
batch_size, num_concurrent_media, num_tiles, num_patches, -1)
hidden_state = torch.cat([hidden_state, intermediate_hidden_states],
dim=-1)
return hidden_state
class MllamaTextRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
MllamaTextRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance +
self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class MllamaTextCrossAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config: Optional[config_mllama.MllamaTextConfig] = None,
layer_idx: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.pipeline_parallel_rank = get_pp_group().rank_in_group
self.tensor_parallel_size = get_tp_group().world_size
self.num_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_local_heads = self.num_heads // self.tensor_parallel_size
self.num_local_key_value_heads = \
self.num_key_value_heads // self.tensor_parallel_size
self.hidden_size = config.hidden_size
self.head_dim = config.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.layer_idx = layer_idx
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.q_local_size = self.num_local_heads * self.head_dim
self.kv_local_size = self.num_local_key_value_heads * self.head_dim
self.qkv_proj = QKVCrossParallelLinear(
self.hidden_size,
self.head_dim,
self.num_heads,
self.num_key_value_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
self.hidden_size,
bias=False,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
        # vllm.model_executor.layers.layernorm.RMSNorm has a precision issue,
        # so use HuggingFace's implementation instead.
self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.scaling = self.head_dim**-0.5
self.attn = Attention(
self.num_local_heads,
self.head_dim,
self.scaling,
self.num_local_key_value_heads,
prefix=f"{prefix}.attn",
attn_type=AttentionType.ENCODER_DECODER,
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
kv_range_for_decode: Optional[List[Tuple[int, int]]],
cross_attention_states: Optional[torch.Tensor],
) -> torch.Tensor:
q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
if cross_attention_states is not None:
k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
k = self.k_norm(k)
q = q.view(-1, self.num_local_heads, self.head_dim)
q = self.q_norm(q)
if attention_mask is not None:
output = self._attention_with_mask(q, k, v, attention_mask,
kv_range_for_decode)
else:
output = self.attn(
q.view(-1, self.num_local_heads * self.head_dim), k, v)
out, _ = self.o_proj(output)
return out
def _attention_with_mask(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attention_mask: torch.Tensor,
kv_range_for_decode: List[Tuple[int, int]],
) -> torch.Tensor:
kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank]
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
# Skip writing kv-cache for the initial profiling run.
# TODO (NickLucche) replace with custom attn bias and use standard attn
if len(kv_cache.shape) > 1:
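            # kv_range_for_decode holds, per sequence, the [start, end) slice
            # of encoder tokens that decode steps will cross-attend to (the
            # last group of images); only those K/V rows are written to the
            # cross-attention KV cache.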
i = torch.ones(1, dtype=torch.float32)
if self.attn.backend in (_Backend.FLASH_ATTN,
_Backend.FLASH_ATTN_VLLM_V1):
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
torch.ops._C_cache_ops.reshape_and_cache_flash(
cached_k,
cached_v,
kv_cache[0],
kv_cache[1],
attn_metadata.
cross_slot_mapping, # type: ignore[union-attr]
"auto",
i,
i,
)
elif self.attn.backend in (_Backend.XFORMERS, _Backend.ROCM_FLASH,
_Backend.TORCH_SDPA):
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_local_key_value_heads, self.head_dim)
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
PagedAttention.write_to_paged_cache(
cached_k, cached_v, key_cache, value_cache,
attn_metadata.cross_slot_mapping, "auto", i, i)
else:
raise ValueError(
f"Unsupported Attention backend {self.attn.backend} "
"enum found. Expected the Attention backend to be "
"FLASH_ATTN, FLASH_ATTN_VLLM_V1, "
"XFORMERS or TORCH_SDPA.")
        # We have to call torch's SDPA for prefill when using a custom
        # cross-attention mask, because the mask is neither a standard
        # causal mask nor a block-diagonal mask that could be optimized
        # with xformers.BlockDiagonalMask. The mask is computed specifically
        # to support multiple and interleaved images.
q_len = q.shape[0]
kv_len = k.shape[0]
q = q.transpose(0, 1).view(self.num_local_key_value_heads,
self.num_key_value_groups, q_len,
self.head_dim).contiguous()
k = k.transpose(0,
1)[:,
None, :, :].expand(self.num_local_key_value_heads,
self.num_key_value_groups,
kv_len,
self.head_dim).contiguous()
v = v.transpose(0,
1)[:,
None, :, :].expand(self.num_local_key_value_heads,
self.num_key_value_groups,
kv_len,
self.head_dim).contiguous()
attention_mask = attention_mask.view(1, 1, q_len, kv_len)
output = F.scaled_dot_product_attention(q,
k,
v,
attn_mask=attention_mask,
is_causal=False)
output = output.permute(2, 0, 1, 3).reshape(
q_len, self.num_local_heads * self.head_dim)
return output
class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
"""Cross-attention transformer block with tanh-gated attention
and feedforward."""
def __init__(
self,
config: config_mllama.MllamaTextConfig,
layer_idx: int,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.cross_attn = MllamaTextCrossAttention(
config=config,
layer_idx=layer_idx,
quant_config=quant_config,
prefix=f"{prefix}.cross_attn",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))
self.mlp = LlamaMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1))
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: torch.Tensor,
cross_attention_mask: torch.Tensor,
kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: torch.Tensor,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.cross_attn(
hidden_states=hidden_states,
attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
cross_attention_states=cross_attention_states,
)
hidden_states = full_text_row_masked_out_mask * hidden_states
hidden_states = residual + self.cross_attn_attn_gate.tanh(
) * hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = full_text_row_masked_out_mask * hidden_states
hidden_states = residual + self.cross_attn_mlp_gate.tanh(
) * hidden_states
return hidden_states
class MllamaTextModel(nn.Module):
config_class = config_mllama.MllamaTextConfig
base_model_prefix = "model"
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config.text_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8,
config.hidden_size)
self.cross_attention_layers = config.cross_attention_layers
layers = []
for layer_idx in range(config.num_hidden_layers):
if layer_idx in self.cross_attention_layers:
layers.append(
MllamaCrossAttentionDecoderLayer(
config,
layer_idx,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
))
else:
# TODO: force LlamaDecoderLayer to config.attention_bias=False
layers.append(
LlamaDecoderLayer(
config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
))
self.layers = nn.ModuleList(layers)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.LongTensor,
positions: Optional[torch.LongTensor],
cross_attention_states: Optional[torch.LongTensor],
cross_attention_mask: Optional[torch.LongTensor],
kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
torch.Tensor]],
skip_cross_attention: bool,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
for idx, decoder_layer in enumerate(self.layers):
if idx in self.cross_attention_layers:
if not skip_cross_attention:
hidden_states = decoder_layer(
hidden_states=hidden_states,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask=
full_text_row_masked_out_mask,
)
else:
hidden_states, residual = decoder_layer(
positions=positions,
hidden_states=hidden_states,
residual=None,
)
hidden_states = hidden_states + residual
hidden_states = self.norm(hidden_states)
return hidden_states
class MllamaForCausalLM(nn.Module):
config_class = config_mllama.MllamaTextConfig
base_model_prefix = "language_model"
_no_split_modules = [
"MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"
]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config.text_config
quant_config = vllm_config.quant_config
self.vocab_size = config.vocab_size
self.model = MllamaTextModel(vllm_config=vllm_config,
prefix=f"{prefix}.model")
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config,
prefix=f"{prefix}.lm_head",
)
def forward(
self,
input_ids: torch.LongTensor,
positions: Optional[torch.LongTensor],
cross_attention_states: Optional[torch.LongTensor],
cross_attention_mask: Optional[torch.LongTensor],
kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
torch.Tensor]],
skip_cross_attention: bool,
) -> torch.Tensor:
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
skip_cross_attention=skip_cross_attention,
)
return hidden_states
@MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor,
info=MllamaProcessingInfo,
dummy_inputs=MllamaDummyInputsBuilder)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsV0Only):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: MllamaConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.quant_config = quant_config
self.vocab_size = config.text_config.vocab_size
self.hidden_size = config.text_config.hidden_size
self.max_num_tiles = config.vision_config.max_num_tiles
self.vision_output_dim = config.vision_config.vision_output_dim
self.pad_token_id = \
config.pad_token_id if config.pad_token_id is not None else -1
self.image_size = config.vision_config.image_size
self.image_token_id = config.image_token_index
self.vision_model = MllamaVisionModel(config.vision_config,
quant_config,
prefix=maybe_prefix(
prefix, "vision_model"))
self.language_model = MllamaForCausalLM(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.multi_modal_projector = ColumnParallelLinear(
config.vision_config.vision_output_dim,
config.text_config.hidden_size,
bias=True,
quant_config=quant_config,
gather_output=True,
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
self.logits_processor = LogitsProcessor(config.output_hidden_states,
config.text_config.vocab_size)
self.sampler = get_sampler()
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.language_model.lm_head,
hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def unpack_data(self,
image_data: Union[List[torch.Tensor], torch.Tensor],
padding_value=0) -> torch.Tensor:
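        """Stack a list of image tensors into a single padded batch tensor.

        Entries may differ in their first dimension (e.g. number of images
        per request); they are right-padded with `padding_value` to the
        longest entry, giving a (batch, max_len, *trailing_dims) tensor.
        """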
if isinstance(image_data, torch.Tensor):
# torch.Tensor
return image_data
else:
assert isinstance(
image_data[0],
torch.Tensor), "Image data is not properly batched."
# List[torch.Tensor]
bsz = len(image_data)
max_length = max(t.size(0) for t in image_data)
trailing_dims = image_data[0].shape[1:]
for data in image_data:
cur_trailing_dims = data.shape[1:]
assert cur_trailing_dims == trailing_dims
output_tensor = torch.full((bsz, max_length, *trailing_dims),
padding_value,
dtype=image_data[0].dtype,
device=image_data[0].device)
for i, t in enumerate(image_data):
output_tensor[i, :t.size(0)] = t
return output_tensor
def _parse_and_validate_image_input(self, **kwargs: object):
        # Tensors with the same shape are batched together by
        # MultiModalKwargs.batch, so pixel_values here can be:
# - List[torch.Tensor]:
# with shape (num_image, num_tiles, 3, image_res, image_res)
# - torch.Tensor:
# with shape (bs, num_image, num_tiles, 3, image_res, image_res)
pixel_values: Optional[Union[List[List[torch.Tensor]],
List[torch.Tensor],
torch.Tensor]] = kwargs.pop(
"pixel_values", None)
image_embeds: Optional[Union[List[List[torch.Tensor]],
List[torch.Tensor],
torch.Tensor]] = kwargs.pop(
"image_embeds", None)
aspect_ratio_ids: Optional[Union[List[List[torch.Tensor]],
List[torch.Tensor],
torch.Tensor]] = kwargs.pop(
"aspect_ratio_ids", None)
aspect_ratio_mask: Optional[Union[List[List[torch.Tensor]],
List[torch.Tensor],
torch.Tensor]] = kwargs.pop(
"aspect_ratio_mask", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None and image_embeds is not None:
raise ValueError(
"Both pixel values and image embeds are provided.")
if pixel_values is not None:
assert aspect_ratio_ids is not None
assert aspect_ratio_mask is not None
return MllamaImagePixelInputs(
type="pixel_values",
data=self.unpack_data(pixel_values),
aspect_ratio_ids=self.unpack_data(aspect_ratio_ids),
aspect_ratio_mask=self.unpack_data(aspect_ratio_mask))
if image_embeds is not None:
raise NotImplementedError
raise AssertionError("This line should be unreachable.")
def flat_encoder_result(self, cross_attention_states: torch.Tensor,
attn_metadata: AttentionMetadata,
actual_encoder_seq_lens: List[int]):
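        # The vision output is padded per sample to the batch maximum; copy
        # only the valid prefix of each sample into one flat
        # (sum(actual_encoder_seq_lens), hidden_size) tensor so that
        # cross-attention can index it with per-sequence token ranges.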
cross_attention_states_flat = torch.zeros(
sum(actual_encoder_seq_lens),
cross_attention_states.shape[-1],
device=cross_attention_states.device,
dtype=cross_attention_states.dtype)
start_pos = 0
for seq_len, vision_token_in_batch in zip(actual_encoder_seq_lens,
cross_attention_states):
end_pos = start_pos + seq_len
cross_attention_states_flat[
start_pos:end_pos] = vision_token_in_batch[:seq_len]
start_pos = end_pos
cross_attention_states = cross_attention_states_flat
return cross_attention_states
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_cross_attention_states(
self,
image_inputs: MllamaImagePixelInputs,
attn_metadata: AttentionMetadata,
actual_encoder_seq_lens: List[int],
    ) -> torch.Tensor:
# NOTE: llama's reference implementation runs vision model on CPU
pixel_values = image_inputs['data']
aspect_ratio_ids = image_inputs['aspect_ratio_ids']
aspect_ratio_mask = image_inputs['aspect_ratio_mask']
cross_attention_states = self.vision_model(pixel_values,
aspect_ratio_ids,
aspect_ratio_mask)
cross_attention_states, _ = self.multi_modal_projector(
cross_attention_states)
bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
cross_attention_states = cross_attention_states.view(
bsz, -1, image_token_dim)
cross_attention_states = self.flat_encoder_result(
cross_attention_states, attn_metadata, actual_encoder_seq_lens)
return cross_attention_states
def get_cross_attention_mask(
self,
input_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
num_tiles: List[List[int]],
num_tokens_per_tile: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
token_ids = input_ids.tolist()
start = 0
batch_token_ids = []
for seq_len in attn_metadata.seq_lens:
batch_token_ids.append(token_ids[start:start + seq_len])
start += seq_len
sparse_mask = [
get_cross_attention_token_mask(t, self.image_token_id)
for t in batch_token_ids
]
# Skip generating cross-attention mask if all samples
# are text-only or have only 1 leading image.
if skip_attention_mask(sparse_mask):
return None, None
dense_mask, tile_range_for_decode = \
convert_sparse_cross_attention_mask_to_dense(
sparse_mask, num_tiles, attn_metadata.seq_lens)
cross_attention_mask = \
convert_dense_cross_attention_mask_to_tensor(
dense_mask, num_tokens_per_tile, input_ids.device, dtype)
kv_range_for_decode = [[
t[0] * num_tokens_per_tile, t[1] * num_tokens_per_tile
] for t in tile_range_for_decode]
return cross_attention_mask, kv_range_for_decode
def get_full_text_row_masked_out_mask(
self,
attn_metadata: AttentionMetadata,
device: torch.device,
) -> torch.Tensor:
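        # Mark rows (prefill tokens) of samples that have no encoder tokens
        # as False, so cross-attention layers contribute nothing for
        # text-only sequences in a mixed batch.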
full_text_row_masked_out_mask = torch.ones(
(attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
start_pos = 0
for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens,
attn_metadata.encoder_seq_lens):
if encoder_seq_len == 0:
full_text_row_masked_out_mask[start_pos:start_pos +
seq_len] = False
start_pos += seq_len
full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
device)
return full_text_row_masked_out_mask
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
**kwargs: object,
) -> Union[Tuple, CausalLMOutputWithPast]:
attn_metadata = get_forward_context().attn_metadata
if attn_metadata.num_prefill_tokens > 0 and \
attn_metadata.num_decode_tokens > 0:
raise ValueError("Chunk prefill not supported")
image_inputs = self._parse_and_validate_image_input(**kwargs)
cross_attention_states = None
cross_attention_mask = None
kv_range_for_decode = None
# For 1) text-only prefill and decode, 2) image-present decode.
if image_inputs is None:
full_text_row_masked_out_mask = (
attn_metadata.encoder_seq_lens_tensor
!= 0).reshape(-1, 1).to(input_ids.device)
skip_cross_attention = attn_metadata.max_encoder_seq_len == 0
# For image-present prefill.
else:
skip_cross_attention = False
            # Get the actual number of encoder tokens for each sample.
            # attn_metadata.encoder_seq_lens only counts the last group of
            # images for each sample; that count is used to make the block
            # manager allocate blocks for those images only.
            # See MllamaMultiModalProcessor for more details.
num_tiles_tensor = kwargs.pop("num_tiles")
num_tiles = [t.tolist() for t in num_tiles_tensor]
num_tokens_per_tile = calc_token_per_chunk(self.image_size)
actual_encoder_seq_lens = [
sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
]
for actual_len, last_group_len in zip(
actual_encoder_seq_lens, attn_metadata.encoder_seq_lens):
assert actual_len >= last_group_len
cross_attention_states = self.get_cross_attention_states(
image_inputs, attn_metadata, actual_encoder_seq_lens)
full_text_row_masked_out_mask = \
self.get_full_text_row_masked_out_mask(
attn_metadata, input_ids.device)
cross_attention_mask, kv_range_for_decode = \
self.get_cross_attention_mask(
input_ids, attn_metadata, num_tiles,
num_tokens_per_tile, cross_attention_states.dtype)
outputs = self.language_model(
input_ids=input_ids,
positions=positions,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
skip_cross_attention=skip_cross_attention,
)
return outputs
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
updated_params: Set[str] = set()
for name, loaded_weight in weights:
if 'patch_embedding.weight' in name:
name = name.replace('patch_embedding.weight',
'patch_embedding._linear.weight')
loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
updated_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
updated_params.add(name)
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
orig_name = name
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
logger.debug("Missing name %s, orig name %s", name,
orig_name)
continue
param = params_dict.pop(name)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
updated_params.add(name)
return updated_params
def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
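    """Return True if the custom cross-attention mask can be skipped.

    That is the case when every sample is either text-only or has exactly
    one image whose range covers the whole sequence ([0, -1], i.e. a single
    leading image); plain encoder-decoder attention then suffices.
    """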
for mask in sparse_mask:
# Skip text-only samples.
if len(mask) == 0:
continue
        # If the sample contains more than one image,
        # we can't skip the mask.
if len(mask) != 1:
return False
        # If the sample contains only one image
        # but the image is not the leading one,
        # we can't skip the mask.
if mask[0][0] != 0 or mask[0][1] != -1:
return False
return True
def convert_sparse_cross_attention_mask_to_dense(
sparse_mask: List[List[List[int]]],
num_tiles: List[List[int]],
lengths: List[int],
) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
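    """Expand per-image token ranges into a dense (token, tile) mask.

    sparse_mask[i][j] is the [start, end] token range of sequence i that
    attends to its j-th image (end == -1 meaning "to the end of the
    sequence"). The result has shape (sum(lengths), total_tiles) with 1
    where a text token may attend to a tile, plus the tile range each
    sequence uses at decode time.
    """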
total_length = sum(lengths)
total_tiles = sum([sum(tiles) for tiles in num_tiles])
dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
    # A list of ranges, one per sequence that has images: range[i] = (start,
    # end) means that sequence's decode tokens will cross-attend to
    # tiles[start:end].
tile_range_for_decode = []
seq_start = 0
tile_start = 0
    # sparse_mask has an empty entry for each sequence that has no images,
    # but num_tiles does not include such entries.
num_tiles_idx = 0
for masks, length in zip(sparse_mask, lengths):
if len(masks) == 0:
# Text only
continue
tiles = num_tiles[num_tiles_idx]
num_tiles_idx += 1
ts, td = -1, 0
for mask, tile in zip(masks, tiles):
if len(mask) != 2:
continue
start, end = mask
end = min(end, length)
if end == -1:
end = length
if end == length:
if ts == -1:
ts = tile_start
td += tile
dense_mask[seq_start + start:seq_start + end,
tile_start:tile_start + tile] = 1
tile_start += tile
assert ts != -1
assert td != 0
tile_range_for_decode.append((ts, ts + td))
seq_start += length
assert num_tiles_idx == len(num_tiles)
return dense_mask, tile_range_for_decode
def convert_dense_cross_attention_mask_to_tensor(
cross_attention_token_mask: np.ndarray,
num_tokens_per_tile: int,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
mask = torch.tensor(cross_attention_token_mask, dtype=dtype, device=device)
mask = mask.repeat_interleave(num_tokens_per_tile, dim=1)
mask = 1.0 - mask
mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(dtype).min)
ninf = torch.finfo(dtype).min
full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
mask *= full_text_mask
# (num_prompt_tokens, num_encoder_tokens)
return mask