[Model] Make llama3.2 support multiple and interleaved images (#9095)

Authored by Xiang Xu on 2024-10-14 15:24:26 -07:00; committed by GitHub.
parent 4d31cd424b
commit f0fe4fe86d
3 changed files with 384 additions and 42 deletions

View File

@@ -234,12 +234,35 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
    )


def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
    model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"

    # The configuration below has been confirmed to launch on a single L40 GPU.
    llm = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=16,
        enforce_eager=True,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    prompt = f"<|image|><|image|><|begin_of_text|>{question}"
    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=None,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


model_example_map = {
    "phi3_v": load_phi3v,
    "internvl_chat": load_internvl,
    "NVLM_D": load_nvlm_d,
    "qwen2_vl": load_qwen2_vl,
    "qwen_vl_chat": load_qwenvl_chat,
    "mllama": load_mllama,
}
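
As a point of reference, here is a minimal sketch of how the new load_mllama entry could be exercised end to end. The image URLs, the question, and the generate call below are illustrative assumptions and are not part of this commit; only load_mllama and ModelRequestData come from the example file itself.

from vllm import SamplingParams

image_urls = [
    "https://example.com/stop_sign.jpg",  # hypothetical URLs
    "https://example.com/cherry_blossom.jpg",
]
req_data = load_mllama("What do these two images have in common?", image_urls)
outputs = req_data.llm.generate(
    {
        "prompt": req_data.prompt,
        "multi_modal_data": {"image": req_data.image_data},
    },
    sampling_params=SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)

Note that the prompt built by load_mllama hard-codes two <|image|> tokens, so this sketch assumes exactly two image URLs.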

View File

@@ -12,7 +12,7 @@ from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
from ....utils import large_gpu_test
from ...utils import check_logprobs_close

_LIMIT_IMAGE_PER_PROMPT = 3

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
@@ -244,8 +244,9 @@ def _run_test(
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
                                     model, sizes, dtype, max_tokens,
                                     num_logprobs) -> None:
    run_test(
        hf_runner,
        vllm_runner,
@@ -257,3 +258,81 @@ def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
                                     model, dtype, max_tokens,
                                     num_logprobs) -> None:

    stop_sign = image_assets[0].pil_image
    cherry_blossom = image_assets[1].pil_image

    inputs = [(
        [
            "<|image|><|image|><|begin_of_text|>Describe 2 images.",  # noqa: E501
            "<|image|><|image|><|begin_of_text|>Describe 2 images.",  # noqa: E501
            "<|image|><|image|><|image|><|begin_of_text|>Describe 3 images.",  # noqa: E501
        ],
        [
            [stop_sign, cherry_blossom],
            # Images with different sizes.
            [
                stop_sign.resize((512, 512)),
                stop_sign,
            ],
            [
                stop_sign,
                stop_sign.resize((512, 1536)),
                cherry_blossom.resize((512, 1024)),
            ],
        ])]

    _run_test(
        hf_runner,
        vllm_runner,
        inputs,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
                                   dtype, max_tokens, num_logprobs) -> None:

    stop_sign = image_assets[0].pil_image
    cherry_blossom = image_assets[1].pil_image

    inputs = [(
        [
            "<|begin_of_text|>The content of the image <|image|> is",  # noqa: E501
            "<|begin_of_text|>Between the first image <|image|> and the second image<|image|>, "  # noqa: E501
            "which is a stop sign and which is a cherry blossom?",  # noqa: E501
        ],
        [
            [stop_sign],
            [stop_sign, cherry_blossom],
        ])]

    _run_test(
        hf_runner,
        vllm_runner,
        inputs,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )

View File

@@ -18,6 +18,7 @@ from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
@@ -28,9 +29,12 @@ from transformers.modeling_outputs import (BaseModelOutput,
                                            CausalLMOutputWithPast)
from transformers.models.mllama.image_processing_mllama import (
    get_optimal_tiled_canvas)
from transformers.models.mllama.processing_mllama import (
    get_cross_attention_token_mask)

import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
@@ -72,6 +76,16 @@ class MllamaImagePixelInputs(TypedDict):
# TODO: support LlamaImageEmbeddingInputs


def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
    num_images = 0
    for token_id in prompt_token_ids[::-1]:
        if token_id == MLLAMA_IMAGE_TOKEN_ID:
            num_images += 1
        elif num_images > 0:
            break
    return num_images
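
# Illustration (added for clarity, not part of this change): with IMG standing
# for MLLAMA_IMAGE_TOKEN_ID, only the trailing run of consecutive image tokens
# is counted, because decoder tokens only cross-attend to that last group:
#   _get_num_image_in_last_group([IMG, 11, 12, IMG, IMG, 13]) -> 2
#   _get_num_image_in_last_group([11, 12, 13]) -> 0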


def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
    # move encoder_prompt to prompt
    if llm_inputs.get("prompt") is None:
@@ -91,12 +105,16 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
        llm_inputs["encoder_multi_modal_data"] = {}
        return llm_inputs

    if isinstance(multi_modal_data['image'], Image.Image):
        multi_modal_data['image'] = [multi_modal_data['image']]
    # Since only the last group of consecutive images is attended to by the
    # decoded tokens, we only need the number of tiles for those images.
    num_decode_images = _get_num_image_in_last_group(
        llm_inputs["prompt_token_ids"])

    hf_config = ctx.model_config.hf_config
    num_tiles = 0
    for image in multi_modal_data["image"][::-1]:
        width, height = image.size
        tile_size = hf_config.vision_config.image_size
        canvas_height, canvas_width = get_optimal_tiled_canvas(
@@ -108,8 +126,13 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
        num_tiles_height = canvas_height // tile_size
        num_tiles_width = canvas_width // tile_size
        num_tiles += num_tiles_height * num_tiles_width
        num_decode_images -= 1
        if num_decode_images == 0:
            break

    # Set the encoder prompt length based on the number of tiles.
    # This tells the block manager to allocate the correct number
    # of slots for encoder tokens.
    assert hf_config.vision_config.image_size % 14 == 0, \
        "chunk size should be multiple of 14"
    token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
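    # Worked example (illustrative; assumes the checkpoint's
    # vision_config.image_size is 560, as in Llama-3.2-11B-Vision):
    # token_per_chunk = (560 // 14) ** 2 + 1 = 1601, so an image tiled onto a
    # 2x2 canvas reserves 4 * 1601 = 6404 encoder token slots.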
@@ -675,6 +698,7 @@ class MllamaTextCrossAttention(nn.Module):
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        kv_range_for_decode: Optional[List[Tuple[int, int]]],
        cross_attention_states: Optional[torch.Tensor],
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
@@ -697,15 +721,71 @@ class MllamaTextCrossAttention(nn.Module):
        q = q.view(-1, self.num_local_heads, self.head_dim)
        q = self.q_norm(q)

        if attention_mask is not None:
            output = self.attention_with_mask(q, k, v, kv_cache,
                                              attention_mask,
                                              kv_range_for_decode,
                                              attn_metadata)
        else:
            output = self.attn(q,
                               k,
                               v,
                               kv_cache,
                               attn_metadata,
                               attn_type=AttentionType.ENCODER_DECODER)
        out, _ = self.o_proj(output)
        return out

    def attention_with_mask(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        kv_cache: torch.Tensor,
        attention_mask: torch.Tensor,
        kv_range_for_decode: List[Tuple[int, int]],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Skip writing kv-cache for the initial profiling run.
        if len(kv_cache.shape) == 3:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_local_key_value_heads, self.head_dim)
            cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
            cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
            PagedAttention.write_to_paged_cache(
                cached_k, cached_v, key_cache, value_cache,
                attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)

        # We have to call torch's SDPA for prefill when a custom
        # cross-attention mask is used, because the mask is neither a
        # standard causal mask nor a block-diagonal mask that
        # xformers.BlockDiagonalMask could optimize. The mask is computed
        # specifically to support multiple and interleaved images.
        q_len = q.shape[0]
        kv_len = k.shape[0]
        q = q.transpose(0, 1).view(self.num_local_key_value_heads,
                                   self.num_key_value_groups, q_len,
                                   self.head_dim)
        k = k.transpose(0,
                        1)[:,
                           None, :, :].expand(self.num_local_key_value_heads,
                                              self.num_key_value_groups,
                                              kv_len, self.head_dim)
        v = v.transpose(0,
                        1)[:,
                           None, :, :].expand(self.num_local_key_value_heads,
                                              self.num_key_value_groups,
                                              kv_len, self.head_dim)
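        # Shape note (added for clarity): q is regrouped to
        # (num_kv_heads, num_groups, q_len, head_dim) while k and v are
        # broadcast along the group axis, so the single
        # scaled_dot_product_attention call below performs grouped-query
        # cross-attention for all heads at once.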
        attention_mask = attention_mask.view(1, 1, q_len, kv_len)
        output = F.scaled_dot_product_attention(q,
                                                k,
                                                v,
                                                attn_mask=attention_mask,
                                                is_causal=False)
        output = output.permute(2, 0, 1, 3).reshape(
            q_len, self.num_local_heads * self.head_dim)
        return output


class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
    """Cross-attention transformer block with tanh-gated attention
@@ -741,6 +821,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
        hidden_states: torch.Tensor,
        cross_attention_states: torch.Tensor,
        cross_attention_mask: torch.Tensor,
        kv_range_for_decode: Optional[List[Tuple[int, int]]],
        full_text_row_masked_out_mask: torch.Tensor,
        kv_cache: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
@@ -751,6 +832,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
        hidden_states = self.cross_attn(
            hidden_states=hidden_states,
            attention_mask=cross_attention_mask,
            kv_range_for_decode=kv_range_for_decode,
            cross_attention_states=cross_attention_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
@@ -804,6 +886,7 @@ class MllamaTextModel(nn.Module):
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        kv_range_for_decode: Optional[List[Tuple[int, int]]],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                      torch.Tensor]],
        kv_caches: List[torch.Tensor],
@@ -820,6 +903,7 @@ class MllamaTextModel(nn.Module):
                    hidden_states=hidden_states,
                    cross_attention_states=cross_attention_states,
                    cross_attention_mask=cross_attention_mask,
                    kv_range_for_decode=kv_range_for_decode,
                    full_text_row_masked_out_mask=
                    full_text_row_masked_out_mask,
                    kv_cache=kv_caches[idx],
@@ -868,6 +952,7 @@ class MllamaForCausalLM(nn.Module):
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        kv_range_for_decode: Optional[List[Tuple[int, int]]],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                      torch.Tensor]],
        kv_caches: List[torch.Tensor],
@@ -879,6 +964,7 @@ class MllamaForCausalLM(nn.Module):
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            kv_range_for_decode=kv_range_for_decode,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
@@ -1026,36 +1112,102 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
        raise AssertionError("This line should be unreachable.")

    def flat_encoder_result(self, cross_attention_states: torch.Tensor,
                            attn_metadata: AttentionMetadata,
                            actual_encoder_seq_lens: List[int]):

        cross_attention_states_flat = torch.zeros(
            sum(actual_encoder_seq_lens),
            cross_attention_states.shape[-1],
            device=cross_attention_states.device,
            dtype=cross_attention_states.dtype)
        start_pos = 0
        for seq_len, vision_token_in_batch in zip(actual_encoder_seq_lens,
                                                  cross_attention_states):
            end_pos = start_pos + seq_len
            cross_attention_states_flat[
                start_pos:end_pos] = vision_token_in_batch[:seq_len]
            start_pos = end_pos
        cross_attention_states = cross_attention_states_flat
        return cross_attention_states

    def get_cross_attention_states(
        self,
        image_inputs: MllamaImagePixelInputs,
        attn_metadata: AttentionMetadata,
        actual_encoder_seq_lens: List[int],
    ) -> Tuple[torch.Tensor]:
        # NOTE: llama's reference implementation runs vision model on CPU
        pixel_values = image_inputs['data']
        aspect_ratio_ids = image_inputs['aspect_ratio_ids']
        aspect_ratio_mask = image_inputs['aspect_ratio_mask']
        cross_attention_states = self.vision_model(pixel_values,
                                                   aspect_ratio_ids,
                                                   aspect_ratio_mask)
        cross_attention_states = self.multi_modal_projector(
            cross_attention_states)

        bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
        cross_attention_states = cross_attention_states.view(
            bsz, -1, image_token_dim)

        cross_attention_states = self.flat_encoder_result(
            cross_attention_states, attn_metadata, actual_encoder_seq_lens)

        return cross_attention_states

    def get_cross_attention_mask(
        self,
        input_ids: torch.Tensor,
        attn_metadata: AttentionMetadata,
        num_tiles: List[List[int]],
        num_tokens_per_tile: int,
        dtype: torch.dtype,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        token_ids = input_ids.tolist()
        start = 0
        batch_token_ids = []
        for seq_len in attn_metadata.seq_lens:
            batch_token_ids.append(token_ids[start:start + seq_len])
            start += seq_len
        sparse_mask = [
            get_cross_attention_token_mask(t, MLLAMA_IMAGE_TOKEN_ID)
            for t in batch_token_ids
        ]

        # Skip generating cross-attention mask if all samples
        # are text-only or have only 1 leading image.
        if skip_attention_mask(sparse_mask):
            return None, None

        dense_mask, tile_range_for_decode = \
            convert_sparse_cross_attention_mask_to_dense(
                sparse_mask, num_tiles, attn_metadata.seq_lens)
        cross_attention_mask = \
            convert_dense_cross_attention_mask_to_tensor(
                dense_mask, num_tokens_per_tile, input_ids.device, dtype)
        kv_range_for_decode = [[
            t[0] * num_tokens_per_tile, t[1] * num_tokens_per_tile
        ] for t in tile_range_for_decode]
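        # Worked example (illustrative; assumes 560x560 tiles, i.e.
        # num_tokens_per_tile == 1601): a tile_range_for_decode entry of
        # (4, 8) becomes the kv range [6404, 12808], so decode-time
        # cross-attention only reads the KV entries of the last image's tiles.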
        return cross_attention_mask, kv_range_for_decode

    def get_full_text_row_masked_out_mask(
        self,
        attn_metadata: AttentionMetadata,
        device: torch.device,
    ) -> torch.Tensor:
        full_text_row_masked_out_mask = torch.ones(
            (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
        start_pos = 0
        for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens,
                                            attn_metadata.encoder_seq_lens):
            if encoder_seq_len == 0:
                full_text_row_masked_out_mask[start_pos:start_pos +
                                              seq_len] = False
            start_pos += seq_len
        full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
            device)
        return full_text_row_masked_out_mask

    def forward(
        self,
@@ -1069,39 +1221,54 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
                attn_metadata.num_decode_tokens > 0:
            raise ValueError("Chunk prefill not supported")
        image_inputs = self._parse_and_validate_image_input(**kwargs)
        cross_attention_states = None
        cross_attention_mask = None
        kv_range_for_decode = None

        # For 1) text-only prefill and decode, 2) image-present decode.
        if image_inputs is None:
            full_text_row_masked_out_mask = (
                attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to(
                    input_ids.device)
            skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0

        # For image-present prefill.
        else:
            skip_cross_attention = False

            # Get the actual number of encoder tokens for each sample, since
            # attn_metadata.encoder_seq_lens only counts the last group of
            # images in each sample; that undercount is used to cheat the
            # block manager into allocating blocks for those images only.
            # See input_processor_for_mllama() for more details.
            num_tiles_tensor = kwargs.pop("num_tiles")
            num_tiles = [t[0].tolist() for t in num_tiles_tensor]
            num_tokens_per_tile = (self.image_size // 14)**2 + 1
            actual_encoder_seq_lens = [
                sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
            ]
            for actual_len, last_group_len in zip(
                    actual_encoder_seq_lens, attn_metadata.encoder_seq_lens):
                assert actual_len >= last_group_len

            cross_attention_states = self.get_cross_attention_states(
                image_inputs, attn_metadata, actual_encoder_seq_lens)

            full_text_row_masked_out_mask = \
                self.get_full_text_row_masked_out_mask(
                    attn_metadata, input_ids.device)

            cross_attention_mask, kv_range_for_decode = \
                self.get_cross_attention_mask(
                    input_ids, attn_metadata, num_tiles,
                    num_tokens_per_tile, cross_attention_states.dtype)

        outputs = self.language_model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            kv_range_for_decode=kv_range_for_decode,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
@@ -1140,3 +1307,76 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)


def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
    for mask in sparse_mask:
        # Skip text-only samples.
        if len(mask) == 0:
            continue
        # If the sample contains more than one image,
        # we can't skip the mask.
        if len(mask) != 1:
            return False
        # If the sample contains only one image,
        # but the image is not the leading one,
        # we can't skip the mask.
        if mask[0][0] != 0 or mask[0][1] != -1:
            return False
    return True
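
# Illustration (added for clarity), using sparse masks of the form produced by
# get_cross_attention_token_mask:
#   skip_attention_mask([[], [[0, -1]]]) -> True      # text-only / one leading image
#   skip_attention_mask([[[0, 5], [5, 10]]]) -> False  # two images in one sample
#   skip_attention_mask([[[3, -1]]]) -> False          # one image, but not leading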


def convert_sparse_cross_attention_mask_to_dense(
    sparse_mask: List[List[List[int]]],
    num_tiles: List[List[int]],
    lengths: List[int],
) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
    total_length = sum(lengths)
    total_tiles = sum([sum(tiles) for tiles in num_tiles])
    dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
    # A list of ranges, where range[i] = (start, end) means that if the
    # i-th sample has N tiles in total, tiles[start:end] will be used
    # for cross-attention decoding.
    tile_range_for_decode = []

    seq_start = 0
    tile_start = 0
    for masks, tiles, length in zip(sparse_mask, num_tiles, lengths):
        ts, td = -1, 0
        for mask, tile in zip(masks, tiles):
            if len(mask) != 2:
                continue
            start, end = mask
            end = min(end, length)
            if end == -1:
                end = length
            if end == length:
                if ts == -1:
                    ts = tile_start
                td += tile
            dense_mask[seq_start + start:seq_start + end,
                       tile_start:tile_start + tile] = 1
            tile_start += tile
        tile_range_for_decode.append((ts, ts + td))
        seq_start += length

    return dense_mask, tile_range_for_decode
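
# Worked example (illustrative): one sample of length 10 with two images, the
# first at token 0 and the second at token 5, each split into 4 tiles:
#   sparse_mask = [[[0, 5], [5, 10]]], num_tiles = [[4, 4]], lengths = [10]
# Tokens 0-4 attend tiles 0-3, tokens 5-9 attend tiles 4-7, and
# tile_range_for_decode == [(4, 8)], i.e. decoding attends only to the tiles
# of the last image.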


def convert_dense_cross_attention_mask_to_tensor(
    cross_attention_token_mask: np.ndarray,
    num_tokens_per_tile: int,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    mask = torch.tensor(cross_attention_token_mask, dtype=dtype, device=device)
    mask = mask.repeat_interleave(num_tokens_per_tile, dim=1)

    mask = 1.0 - mask
    mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(dtype).min)

    ninf = torch.finfo(dtype).min
    full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
    mask *= full_text_mask
    # (num_prompt_tokens, num_encoder_tokens)
    return mask
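
# Behavior note (added for clarity, not part of this change): in the returned
# tensor, attended (token, tile-token) positions are 0.0 and blocked positions
# hold the dtype's minimum value, forming an additive bias for
# scaled_dot_product_attention. Rows that attend to no tile at all (e.g. text
# tokens that precede the first image) would otherwise be fully blocked and
# yield NaNs in softmax, so full_text_mask zeroes those rows here; their output
# is suppressed separately through full_text_row_masked_out_mask.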