[Model] Make llama3.2 support multiple and interleaved images (#9095)
parent 4d31cd424b · commit f0fe4fe86d
@@ -234,12 +234,35 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
     )
 
 
+def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
+    model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+
+    # The configuration below has been confirmed to launch on a single L40 GPU.
+    llm = LLM(
+        model=model_name,
+        max_model_len=4096,
+        max_num_seqs=16,
+        enforce_eager=True,
+        limit_mm_per_prompt={"image": len(image_urls)},
+    )
+
+    prompt = f"<|image|><|image|><|begin_of_text|>{question}"
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=None,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=None,
+    )
+
+
 model_example_map = {
     "phi3_v": load_phi3v,
     "internvl_chat": load_internvl,
     "NVLM_D": load_nvlm_d,
     "qwen2_vl": load_qwen2_vl,
     "qwen_vl_chat": load_qwenvl_chat,
+    "mllama": load_mllama,
 }
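For context only (not part of this diff): a minimal sketch of how the returned ModelRequestData is typically consumed, assuming the usual vLLM flow of passing PIL images through multi_modal_data as the other loaders in this example script do. The URLs are placeholders.

from vllm import SamplingParams

image_urls = [
    "https://example.com/stop_sign.jpg",        # placeholder URLs (assumption)
    "https://example.com/cherry_blossom.jpg",
]
req = load_mllama("What is the content of each image?", image_urls)
outputs = req.llm.generate(
    {
        "prompt": req.prompt,
        "multi_modal_data": {"image": req.image_data},
    },
    sampling_params=SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)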
@@ -12,7 +12,7 @@ from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
 from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
 
-_LIMIT_IMAGE_PER_PROMPT = 1
+_LIMIT_IMAGE_PER_PROMPT = 3
 
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "stop_sign":
@@ -244,8 +244,9 @@ def _run_test(
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
-def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype,
-                max_tokens, num_logprobs) -> None:
+def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
+                                     model, sizes, dtype, max_tokens,
+                                     num_logprobs) -> None:
     run_test(
         hf_runner,
         vllm_runner,
@@ -257,3 +258,81 @@ def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype,
         num_logprobs=num_logprobs,
         tensor_parallel_size=1,
     )
+
+
+@large_gpu_test(min_gb=48)
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
+                                     model, dtype, max_tokens,
+                                     num_logprobs) -> None:
+
+    stop_sign = image_assets[0].pil_image
+    cherry_blossom = image_assets[1].pil_image
+
+    inputs = [(
+        [
+            "<|image|><|image|><|begin_of_text|>Describe 2 images.",  # noqa: E501
+            "<|image|><|image|><|begin_of_text|>Describe 2 images.",  # noqa: E501
+            "<|image|><|image|><|image|><|begin_of_text|>Describe 3 images.",  # noqa: E501
+        ],
+        [
+            [stop_sign, cherry_blossom],
+            # Images with different sizes.
+            [
+                stop_sign.resize((512, 512)),
+                stop_sign,
+            ],
+            [
+                stop_sign,
+                stop_sign.resize((512, 1536)),
+                cherry_blossom.resize((512, 1024)),
+            ],
+        ])]
+
+    _run_test(
+        hf_runner,
+        vllm_runner,
+        inputs,
+        model,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
+
+
+@large_gpu_test(min_gb=48)
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
+                                   dtype, max_tokens, num_logprobs) -> None:
+
+    stop_sign = image_assets[0].pil_image
+    cherry_blossom = image_assets[1].pil_image
+
+    inputs = [(
+        [
+            "<|begin_of_text|>The content of the image <|image|> is",  # noqa: E501
+            "<|begin_of_text|>Between the first image <|image|> and the second image<|image|>, "  # noqa: E501
+            "which is a stop sign and which is a cherry blossom?",  # noqa: E501
+        ],
+        [
+            [stop_sign],
+            [stop_sign, cherry_blossom],
+        ])]
+
+    _run_test(
+        hf_runner,
+        vllm_runner,
+        inputs,
+        model,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
@@ -18,6 +18,7 @@ from array import array
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -28,9 +29,12 @@ from transformers.modeling_outputs import (BaseModelOutput,
                                            CausalLMOutputWithPast)
+from transformers.models.mllama.image_processing_mllama import (
+    get_optimal_tiled_canvas)
 from transformers.models.mllama.processing_mllama import (
     get_cross_attention_token_mask)
 
 import vllm.distributed.parallel_state as ps
 from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
@@ -72,6 +76,16 @@ class MllamaImagePixelInputs(TypedDict):
     # TODO: support LlamaImageEmbeddingInputs
 
 
+def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
+    num_images = 0
+    for token_id in prompt_token_ids[::-1]:
+        if token_id == MLLAMA_IMAGE_TOKEN_ID:
+            num_images += 1
+        elif num_images > 0:
+            break
+    return num_images
+
+
 def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
     # move encoder_prompt to prompt
     if llm_inputs.get("prompt") is None:
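A quick illustration of _get_num_image_in_last_group above (the token ids and the value of MLLAMA_IMAGE_TOKEN_ID are invented for the example): only the last run of consecutive image tokens is counted.

# Pretend MLLAMA_IMAGE_TOKEN_ID == 128256 (assumption) and the token ids
# correspond to "<|image|> text <|image|><|image|> more text".
prompt_token_ids = [128256, 100, 101, 128256, 128256, 102, 103]
# Scanning from the end: trailing text tokens are skipped while the count is
# still zero, the two adjacent image tokens are counted, and the scan stops at
# the first non-image token after them.
assert _get_num_image_in_last_group(prompt_token_ids) == 2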
@@ -91,12 +105,16 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
         llm_inputs["encoder_multi_modal_data"] = {}
         return llm_inputs
 
-    # get num_tiles
     if isinstance(multi_modal_data['image'], Image.Image):
         multi_modal_data['image'] = [multi_modal_data['image']]
+    # Since only the last group of consecutive images is attended
+    # to by the decoded tokens, we only need to get the number of
+    # tiles for those images.
+    num_decode_images = _get_num_image_in_last_group(
+        llm_inputs["prompt_token_ids"])
     hf_config = ctx.model_config.hf_config
     num_tiles = 0
-    for image in multi_modal_data["image"]:
+    for image in multi_modal_data["image"][::-1]:
         width, height = image.size
         tile_size = hf_config.vision_config.image_size
         canvas_height, canvas_width = get_optimal_tiled_canvas(
@@ -108,8 +126,13 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
         num_tiles_height = canvas_height // tile_size
         num_tiles_width = canvas_width // tile_size
         num_tiles += num_tiles_height * num_tiles_width
+        num_decode_images -= 1
+        if num_decode_images == 0:
+            break
 
-    # set encoder prompt based on num_tiles
+    # Set the encoder prompt length based on the number of tiles.
+    # This tells the block manager to allocate the correct number
+    # of slots for the encoder tokens.
     assert hf_config.vision_config.image_size % 14 == 0, \
         "chunk size should be multiple of 14"
     token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
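Rough arithmetic behind this bookkeeping, assuming the stock Llama 3.2 Vision config (image_size=560, 14-pixel patches); the exact numbers depend on the checkpoint's vision config.

# Each 560x560 tile yields (560 // 14) ** 2 = 1600 patch tokens plus one extra
# token, so token_per_chunk = 1601.
image_size = 560                                # assumed vision_config.image_size
token_per_chunk = (image_size // 14) ** 2 + 1   # 1601

# If the last image group of a prompt maps to a 2-tile canvas, the encoder
# prompt reserves 2 * 1601 = 3202 slots; images outside the last group are
# excluded here because only that group is cross-attended during decode.
num_tiles = 2
encoder_seq_len = num_tiles * token_per_chunk   # 3202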
@@ -675,6 +698,7 @@ class MllamaTextCrossAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor],
+        kv_range_for_decode: Optional[List[Tuple[int, int]]],
         cross_attention_states: Optional[torch.Tensor],
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
@@ -697,15 +721,71 @@ class MllamaTextCrossAttention(nn.Module):
         q = q.view(-1, self.num_local_heads, self.head_dim)
         q = self.q_norm(q)
 
-        output = self.attn(q,
-                           k,
-                           v,
-                           kv_cache,
-                           attn_metadata,
-                           attn_type=AttentionType.ENCODER_DECODER)
+        if attention_mask is not None:
+            output = self.attention_with_mask(q, k, v, kv_cache,
+                                              attention_mask,
+                                              kv_range_for_decode,
+                                              attn_metadata)
+        else:
+            output = self.attn(q,
+                               k,
+                               v,
+                               kv_cache,
+                               attn_metadata,
+                               attn_type=AttentionType.ENCODER_DECODER)
         out, _ = self.o_proj(output)
         return out
 
+    def attention_with_mask(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attention_mask: torch.Tensor,
+        kv_range_for_decode: List[Tuple[int, int]],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        # Skip writing kv-cache for the initial profiling run.
+        if len(kv_cache.shape) == 3:
+            key_cache, value_cache = PagedAttention.split_kv_cache(
+                kv_cache, self.num_local_key_value_heads, self.head_dim)
+            cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
+            cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
+            PagedAttention.write_to_paged_cache(
+                cached_k, cached_v, key_cache, value_cache,
+                attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
+        # We have to use torch.sdpa for prefill when a custom
+        # cross-attention mask is given, because the mask is neither a
+        # standard causal mask nor a block-diagonal mask that could be
+        # optimized by xformers.BlockDiagonalMask.
+        # The mask is specially calculated to support multiple and
+        # interleaved images.
+        q_len = q.shape[0]
+        kv_len = k.shape[0]
+        q = q.transpose(0, 1).view(self.num_local_key_value_heads,
+                                   self.num_key_value_groups, q_len,
+                                   self.head_dim)
+        k = k.transpose(0,
+                        1)[:,
+                           None, :, :].expand(self.num_local_key_value_heads,
+                                              self.num_key_value_groups,
+                                              kv_len, self.head_dim)
+        v = v.transpose(0,
+                        1)[:,
+                           None, :, :].expand(self.num_local_key_value_heads,
+                                              self.num_key_value_groups,
+                                              kv_len, self.head_dim)
+        attention_mask = attention_mask.view(1, 1, q_len, kv_len)
+        output = F.scaled_dot_product_attention(q,
+                                                k,
+                                                v,
+                                                attn_mask=attention_mask,
+                                                is_causal=False)
+        output = output.permute(2, 0, 1, 3).reshape(
+            q_len, self.num_local_heads * self.head_dim)
+        return output
+
 
 class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
     """Cross-attention transformer block with tanh-gated attention
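As a standalone illustration of the masked prefill path in attention_with_mask above: a minimal sketch of grouped-query attention through torch's scaled_dot_product_attention with an arbitrary non-causal additive mask. Head counts, lengths, and the masked region are invented, and the paged KV-cache write is omitted.

import torch
import torch.nn.functional as F

num_kv_heads, num_kv_groups, head_dim = 2, 4, 64    # assumed sizes
q_len, kv_len = 10, 12

q = torch.randn(q_len, num_kv_heads * num_kv_groups, head_dim)
k = torch.randn(kv_len, num_kv_heads, head_dim)
v = torch.randn(kv_len, num_kv_heads, head_dim)

# Additive mask: 0 where attention is allowed, a large negative value where it
# is blocked (e.g. tiles belonging to an image the token should not attend to).
mask = torch.zeros(q_len, kv_len)
mask[:, 6:] = torch.finfo(mask.dtype).min

# Same reshaping as attention_with_mask: fold the query heads into
# (kv_heads, groups) and broadcast K/V across the groups instead of
# materializing repeated heads.
q = q.transpose(0, 1).view(num_kv_heads, num_kv_groups, q_len, head_dim)
k = k.transpose(0, 1)[:, None].expand(num_kv_heads, num_kv_groups, kv_len,
                                      head_dim)
v = v.transpose(0, 1)[:, None].expand(num_kv_heads, num_kv_groups, kv_len,
                                      head_dim)

out = F.scaled_dot_product_attention(q, k, v,
                                     attn_mask=mask.view(1, 1, q_len, kv_len),
                                     is_causal=False)
out = out.permute(2, 0, 1, 3).reshape(q_len,
                                      num_kv_heads * num_kv_groups * head_dim)
print(out.shape)    # torch.Size([10, 512])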
@@ -741,6 +821,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
         hidden_states: torch.Tensor,
         cross_attention_states: torch.Tensor,
         cross_attention_mask: torch.Tensor,
+        kv_range_for_decode: Optional[List[Tuple[int, int]]],
         full_text_row_masked_out_mask: torch.Tensor,
         kv_cache: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
@@ -751,6 +832,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
         hidden_states = self.cross_attn(
             hidden_states=hidden_states,
             attention_mask=cross_attention_mask,
+            kv_range_for_decode=kv_range_for_decode,
             cross_attention_states=cross_attention_states,
             kv_cache=kv_cache,
             attn_metadata=attn_metadata,
@@ -804,6 +886,7 @@ class MllamaTextModel(nn.Module):
         positions: Optional[torch.LongTensor],
         cross_attention_states: Optional[torch.LongTensor],
         cross_attention_mask: Optional[torch.LongTensor],
+        kv_range_for_decode: Optional[List[Tuple[int, int]]],
         full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                       torch.Tensor]],
         kv_caches: List[torch.Tensor],
@@ -820,6 +903,7 @@ class MllamaTextModel(nn.Module):
                 hidden_states=hidden_states,
                 cross_attention_states=cross_attention_states,
                 cross_attention_mask=cross_attention_mask,
+                kv_range_for_decode=kv_range_for_decode,
                 full_text_row_masked_out_mask=
                 full_text_row_masked_out_mask,
                 kv_cache=kv_caches[idx],
@@ -868,6 +952,7 @@ class MllamaForCausalLM(nn.Module):
         positions: Optional[torch.LongTensor],
         cross_attention_states: Optional[torch.LongTensor],
         cross_attention_mask: Optional[torch.LongTensor],
+        kv_range_for_decode: Optional[List[Tuple[int, int]]],
         full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                       torch.Tensor]],
         kv_caches: List[torch.Tensor],
@@ -879,6 +964,7 @@ class MllamaForCausalLM(nn.Module):
             positions=positions,
             cross_attention_states=cross_attention_states,
             cross_attention_mask=cross_attention_mask,
+            kv_range_for_decode=kv_range_for_decode,
             full_text_row_masked_out_mask=full_text_row_masked_out_mask,
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
@@ -1026,36 +1112,102 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
             raise AssertionError("This line should be unreachable.")
 
     def flat_encoder_result(self, cross_attention_states: torch.Tensor,
-                            attn_metadata: AttentionMetadata):
+                            attn_metadata: AttentionMetadata,
+                            actual_encoder_seq_lens: List[int]):
 
         cross_attention_states_flat = torch.zeros(
-            sum(attn_metadata.encoder_seq_lens),
+            sum(actual_encoder_seq_lens),
             cross_attention_states.shape[-1],
             device=cross_attention_states.device,
             dtype=cross_attention_states.dtype)
         start_pos = 0
-        for seq_len, vision_token_in_batch in zip(
-                attn_metadata.encoder_seq_lens, cross_attention_states):
+        for seq_len, vision_token_in_batch in zip(actual_encoder_seq_lens,
+                                                  cross_attention_states):
             end_pos = start_pos + seq_len
             cross_attention_states_flat[
                 start_pos:end_pos] = vision_token_in_batch[:seq_len]
             start_pos = end_pos
         cross_attention_states = cross_attention_states_flat
+        return cross_attention_states
+
+    def get_cross_attention_states(
+        self,
+        image_inputs: MllamaImagePixelInputs,
+        attn_metadata: AttentionMetadata,
+        actual_encoder_seq_lens: List[int],
+    ) -> Tuple[torch.Tensor]:
+        # NOTE: llama's reference implementation runs vision model on CPU
+        pixel_values = image_inputs['data']
+        aspect_ratio_ids = image_inputs['aspect_ratio_ids']
+        aspect_ratio_mask = image_inputs['aspect_ratio_mask']
+        cross_attention_states = self.vision_model(pixel_values,
+                                                   aspect_ratio_ids,
+                                                   aspect_ratio_mask)
+        cross_attention_states = self.multi_modal_projector(
+            cross_attention_states)
+
+        bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
+        cross_attention_states = cross_attention_states.view(
+            bsz, -1, image_token_dim)
+
+        cross_attention_states = self.flat_encoder_result(
+            cross_attention_states, attn_metadata, actual_encoder_seq_lens)
+
+        return cross_attention_states
+
+    def get_cross_attention_mask(
+        self,
+        input_ids: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        num_tiles: List[List[int]],
+        num_tokens_per_tile: int,
+        dtype: torch.dtype,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        token_ids = input_ids.tolist()
+        start = 0
+        batch_token_ids = []
+        for seq_len in attn_metadata.seq_lens:
+            batch_token_ids.append(token_ids[start:start + seq_len])
+            start += seq_len
+        sparse_mask = [
+            get_cross_attention_token_mask(t, MLLAMA_IMAGE_TOKEN_ID)
+            for t in batch_token_ids
+        ]
+
+        # Skip generating the cross-attention mask if all samples
+        # are text-only or have only one leading image.
+        if skip_attention_mask(sparse_mask):
+            return None, None
+
+        dense_mask, tile_range_for_decode = \
+            convert_sparse_cross_attention_mask_to_dense(
+                sparse_mask, num_tiles, attn_metadata.seq_lens)
+        cross_attention_mask = \
+            convert_dense_cross_attention_mask_to_tensor(
+                dense_mask, num_tokens_per_tile, input_ids.device, dtype)
+        kv_range_for_decode = [[
+            t[0] * num_tokens_per_tile, t[1] * num_tokens_per_tile
+        ] for t in tile_range_for_decode]
+
+        return cross_attention_mask, kv_range_for_decode
+
+    def get_full_text_row_masked_out_mask(
+        self,
+        attn_metadata: AttentionMetadata,
+        device: torch.device,
+    ) -> torch.Tensor:
         full_text_row_masked_out_mask = torch.ones(
             (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
         start_pos = 0
-        for seq_len, encoder_seq_len in zip(
-                attn_metadata.seq_lens_tensor.cpu(),
-                attn_metadata.encoder_seq_lens):
+        for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens,
+                                            attn_metadata.encoder_seq_lens):
             if encoder_seq_len == 0:
                 full_text_row_masked_out_mask[start_pos:start_pos +
                                               seq_len] = False
             start_pos += seq_len
         full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
-            cross_attention_states.device)
-
-        return cross_attention_states, full_text_row_masked_out_mask
+            device)
+        return full_text_row_masked_out_mask
 
     def forward(
         self,
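A small numeric illustration of the kv_range_for_decode computed in get_cross_attention_mask above (the tile ranges and the tile token count are example values):

# Suppose convert_sparse_cross_attention_mask_to_dense returned the tile
# ranges (0, 2) and (2, 6) for two sequences, and each tile expands to
# 1601 encoder tokens.
num_tokens_per_tile = 1601
tile_range_for_decode = [(0, 2), (2, 6)]
kv_range_for_decode = [[t[0] * num_tokens_per_tile, t[1] * num_tokens_per_tile]
                       for t in tile_range_for_decode]
print(kv_range_for_decode)    # [[0, 3202], [3202, 9606]]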
@@ -1069,39 +1221,54 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
                 attn_metadata.num_decode_tokens > 0:
             raise ValueError("Chunk prefill not supported")
         image_inputs = self._parse_and_validate_image_input(**kwargs)
+        cross_attention_states = None
+        cross_attention_mask = None
+        kv_range_for_decode = None
 
+        # For 1) text-only prefill and decode, 2) image-present decode.
         if image_inputs is None:
-            cross_attention_mask = None
             full_text_row_masked_out_mask = (
                 attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to(
                     input_ids.device)
-            cross_attention_states = None
             skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0
+
+        # For image-present prefill.
         else:
-            # NOTE: llama's reference implementation runs vision model on CPU
-            pixel_values = image_inputs['data']
-            aspect_ratio_ids = image_inputs['aspect_ratio_ids']
-            aspect_ratio_mask = image_inputs['aspect_ratio_mask']
-            cross_attention_states = self.vision_model(pixel_values,
-                                                       aspect_ratio_ids,
-                                                       aspect_ratio_mask)
-            cross_attention_states = self.multi_modal_projector(
-                cross_attention_states)
-
-            bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
-            cross_attention_states = cross_attention_states.view(
-                bsz, -1, image_token_dim)
-
-            cross_attention_states, full_text_row_masked_out_mask = \
-                self.flat_encoder_result(cross_attention_states, attn_metadata)
             skip_cross_attention = False
-            # TODO: support multi-image by this mask
-            cross_attention_mask = None
+
+            # Get the actual number of encoder tokens for each sample, because
+            # attn_metadata.encoder_seq_lens only counts the last group of
+            # images of each sample; that count is used to make the block
+            # manager allocate blocks for those images only.
+            # See input_processor_for_mllama() for more details.
+            num_tiles_tensor = kwargs.pop("num_tiles")
+            num_tiles = [t[0].tolist() for t in num_tiles_tensor]
+            num_tokens_per_tile = (self.image_size // 14)**2 + 1
+            actual_encoder_seq_lens = [
+                sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
+            ]
+            for actual_len, last_group_len in zip(
+                    actual_encoder_seq_lens, attn_metadata.encoder_seq_lens):
+                assert actual_len >= last_group_len
+
+            cross_attention_states = self.get_cross_attention_states(
+                image_inputs, attn_metadata, actual_encoder_seq_lens)
+
+            full_text_row_masked_out_mask = \
+                self.get_full_text_row_masked_out_mask(
+                    attn_metadata, input_ids.device)
+
+            cross_attention_mask, kv_range_for_decode = \
+                self.get_cross_attention_mask(
+                    input_ids, attn_metadata, num_tiles,
+                    num_tokens_per_tile, cross_attention_states.dtype)
 
         outputs = self.language_model(
             input_ids=input_ids,
             positions=positions,
             cross_attention_states=cross_attention_states,
             cross_attention_mask=cross_attention_mask,
+            kv_range_for_decode=kv_range_for_decode,
             full_text_row_masked_out_mask=full_text_row_masked_out_mask,
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
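To make the actual_encoder_seq_lens bookkeeping above concrete (the values here are invented): with two prompts in a prefill batch,

num_tiles = [[4], [2, 4]]            # tiles per image, per prompt (assumed)
num_tokens_per_tile = (560 // 14)**2 + 1           # 1601 for image_size == 560
actual_encoder_seq_lens = [
    sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
]
print(actual_encoder_seq_lens)       # [6404, 9606]
# attn_metadata.encoder_seq_lens would only cover each prompt's last image
# group, hence the assert actual_len >= last_group_len.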
@@ -1140,3 +1307,76 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+
+
+def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
+    for mask in sparse_mask:
+        # Skip text-only samples.
+        if len(mask) == 0:
+            continue
+        # If the sample contains more than one image,
+        # we can't skip the mask.
+        if len(mask) != 1:
+            return False
+        # If the sample contains only one image but the image
+        # is not the leading one, we can't skip the mask.
+        if mask[0][0] != 0 or mask[0][1] != -1:
+            return False
+    return True
+
+
+def convert_sparse_cross_attention_mask_to_dense(
+    sparse_mask: List[List[List[int]]],
+    num_tiles: List[List[int]],
+    lengths: List[int],
+) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
+    total_length = sum(lengths)
+    total_tiles = sum([sum(tiles) for tiles in num_tiles])
+    dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
+    # A list of ranges, where range[i] = [start, end] means that
+    # if the i-th sample has N tiles in total, tiles[start:end]
+    # will be used for cross-attention decoding.
+    tile_range_for_decode = []
+
+    seq_start = 0
+    tile_start = 0
+    for masks, tiles, length in zip(sparse_mask, num_tiles, lengths):
+        ts, td = -1, 0
+        for mask, tile in zip(masks, tiles):
+            if len(mask) != 2:
+                continue
+            start, end = mask
+            end = min(end, length)
+            if end == -1:
+                end = length
+            if end == length:
+                if ts == -1:
+                    ts = tile_start
+                td += tile
+            dense_mask[seq_start + start:seq_start + end,
+                       tile_start:tile_start + tile] = 1
+            tile_start += tile
+        tile_range_for_decode.append((ts, ts + td))
+        seq_start += length
+
+    return dense_mask, tile_range_for_decode
+
+
+def convert_dense_cross_attention_mask_to_tensor(
+    cross_attention_token_mask: np.ndarray,
+    num_tokens_per_tile: int,
+    device: torch.device,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    mask = torch.tensor(cross_attention_token_mask, dtype=dtype, device=device)
+    mask = mask.repeat_interleave(num_tokens_per_tile, dim=1)
+
+    mask = 1.0 - mask
+    mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(dtype).min)
+
+    ninf = torch.finfo(dtype).min
+    full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
+    mask *= full_text_mask
+    # (num_prompt_tokens, num_encoder_tokens)
+    return mask
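A tiny worked example of the two helpers above, using small invented values (the real num_tokens_per_tile would be 1601 for a 560-pixel tile). It relies on the function definitions from the hunk above.

import torch

# One prompt of 5 tokens with two images: image 0 starts at token 0 and is
# attended by tokens [0, 2); image 1 starts at token 2 and, with end == -1,
# is attended until the end of the prompt.
sparse_mask = [[[0, 2], [2, -1]]]
num_tiles = [[2, 3]]                 # image 0 -> 2 tiles, image 1 -> 3 tiles
lengths = [5]

dense_mask, tile_range_for_decode = \
    convert_sparse_cross_attention_mask_to_dense(sparse_mask, num_tiles,
                                                 lengths)
print(dense_mask)
# [[1 1 0 0 0]
#  [1 1 0 0 0]
#  [0 0 1 1 1]
#  [0 0 1 1 1]
#  [0 0 1 1 1]]
print(tile_range_for_decode)         # [(2, 5)]: only image 1 reaches the end
                                     # of the prompt, so only its tiles are
                                     # used for cross-attention at decode time.

mask = convert_dense_cross_attention_mask_to_tensor(
    dense_mask, num_tokens_per_tile=2, device=torch.device("cpu"),
    dtype=torch.float32)
print(mask.shape)                    # torch.Size([5, 10]): 0 where attention
                                     # is allowed, finfo.min where blocked.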