[Model] Make llama3.2 support multiple and interleaved images (#9095)

Xiang Xu authored on 2024-10-14 15:24:26 -07:00, committed by GitHub
parent 4d31cd424b
commit f0fe4fe86d
3 changed files with 384 additions and 42 deletions


@@ -234,12 +234,35 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
)
def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
# The configuration below has been confirmed to launch on a single L40 GPU.
llm = LLM(
model=model_name,
max_model_len=4096,
max_num_seqs=16,
enforce_eager=True,
limit_mm_per_prompt={"image": len(image_urls)},
)
prompt = f"<|image|><|image|><|begin_of_text|>{question}"
return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=None,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)
model_example_map = {
"phi3_v": load_phi3v,
"internvl_chat": load_internvl,
"NVLM_D": load_nvlm_d,
"qwen2_vl": load_qwen2_vl,
"qwen_vl_chat": load_qwenvl_chat,
"mllama": load_mllama,
}
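For context, the surrounding example script consumes the returned ModelRequestData roughly as in the sketch below (the question and URLs are placeholders; note that the hard-coded prompt above contains exactly two <|image|> tokens, so exactly two images must be supplied):

from vllm import SamplingParams

IMAGE_URLS = [
    "https://example.com/a.jpg",  # placeholder URLs, any two images work
    "https://example.com/b.jpg",
]

req_data = load_mllama("What do these two images have in common?", IMAGE_URLS)
sampling_params = SamplingParams(temperature=0.0,
                                 max_tokens=128,
                                 stop_token_ids=req_data.stop_token_ids)
outputs = req_data.llm.generate(
    {
        "prompt": req_data.prompt,
        "multi_modal_data": {"image": req_data.image_data},
    },
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)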


@@ -12,7 +12,7 @@ from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
from ....utils import large_gpu_test
from ...utils import check_logprobs_close
_LIMIT_IMAGE_PER_PROMPT = 1
_LIMIT_IMAGE_PER_PROMPT = 3
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
@@ -244,8 +244,9 @@ def _run_test(
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype,
max_tokens, num_logprobs) -> None:
def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
model, sizes, dtype, max_tokens,
num_logprobs) -> None:
run_test(
hf_runner,
vllm_runner,
@@ -257,3 +258,81 @@ def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
model, dtype, max_tokens,
num_logprobs) -> None:
stop_sign = image_assets[0].pil_image
cherry_blossom = image_assets[1].pil_image
inputs = [(
[
"<|image|><|image|><|begin_of_text|>Describe 2 images.", # noqa: E501
"<|image|><|image|><|begin_of_text|>Describe 2 images.", # noqa: E501
"<|image|><|image|><|image|><|begin_of_text|>Describe 3 images.", # noqa: E501
],
[
[stop_sign, cherry_blossom],
# Images with different sizes.
[
stop_sign.resize((512, 512)),
stop_sign,
],
[
stop_sign,
stop_sign.resize((512, 1536)),
cherry_blossom.resize((512, 1024)),
],
])]
_run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
dtype, max_tokens, num_logprobs) -> None:
stop_sign = image_assets[0].pil_image
cherry_blossom = image_assets[1].pil_image
inputs = [(
[
"<|begin_of_text|>The content of the image <|image|> is", # noqa: E501
"<|begin_of_text|>Between the first image <|image|> and the second image<|image|>, " # noqa: E501
"which is a stop sign and which is a cherry blossom?", # noqa: E501
],
[
[stop_sign],
[stop_sign, cherry_blossom],
])]
_run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
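Each element of inputs pairs a list of prompts with a matching list of per-prompt image lists; a stripped-down sketch of the same shape (images generated here instead of loaded from the test assets):

from PIL import Image

blank_a = Image.new("RGB", (512, 512))   # stands in for a test asset image
blank_b = Image.new("RGB", (512, 512))
inputs = [(
    ["<|image|><|image|><|begin_of_text|>Describe 2 images."],
    [[blank_a, blank_b]],   # prompt i gets images[i]; the <|image|> count must match
)]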


@@ -18,6 +18,7 @@ from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
@@ -28,9 +29,12 @@ from transformers.modeling_outputs import (BaseModelOutput,
CausalLMOutputWithPast)
from transformers.models.mllama.image_processing_mllama import (
get_optimal_tiled_canvas)
from transformers.models.mllama.processing_mllama import (
get_cross_attention_token_mask)
import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
@@ -72,6 +76,16 @@ class MllamaImagePixelInputs(TypedDict):
# TODO: support LlamaImageEmbeddingInputs
def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
num_images = 0
for token_id in prompt_token_ids[::-1]:
if token_id == MLLAMA_IMAGE_TOKEN_ID:
num_images += 1
elif num_images > 0:
break
return num_images
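A minimal sketch of what the helper above returns (IMG/TXT and the token sequences are illustrative; MLLAMA_IMAGE_TOKEN_ID is the module-level constant referenced in the loop):

IMG = MLLAMA_IMAGE_TOKEN_ID   # image placeholder token
TXT = 1                       # any non-image token id, purely illustrative
# The scan runs right-to-left: trailing text is skipped, the last group of
# consecutive image tokens is counted, and the scan stops just before it.
assert _get_num_image_in_last_group([IMG, TXT, IMG, IMG, TXT, TXT]) == 2
assert _get_num_image_in_last_group([TXT, TXT]) == 0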
def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
# move encoder_prompt to prompt
if llm_inputs.get("prompt") is None:
@@ -91,12 +105,16 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
llm_inputs["encoder_multi_modal_data"] = {}
return llm_inputs
# get num_tiles
if isinstance(multi_modal_data['image'], Image.Image):
multi_modal_data['image'] = [multi_modal_data['image']]
# Since only the last group of consecutive images
# is attended to by the decoded tokens, we only need
# the number of tiles for those images.
num_decode_images = _get_num_image_in_last_group(
llm_inputs["prompt_token_ids"])
hf_config = ctx.model_config.hf_config
num_tiles = 0
for image in multi_modal_data["image"]:
for image in multi_modal_data["image"][::-1]:
width, height = image.size
tile_size = hf_config.vision_config.image_size
canvas_height, canvas_width = get_optimal_tiled_canvas(
@@ -108,8 +126,13 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
num_tiles_height = canvas_height // tile_size
num_tiles_width = canvas_width // tile_size
num_tiles += num_tiles_height * num_tiles_width
num_decode_images -= 1
if num_decode_images == 0:
break
# set encoder prompt based on num_tiles
# Set the encoder prompt length based on the number of tiles.
# This tells the block manager to allocate the correct number
# of slots for the encoder tokens.
assert hf_config.vision_config.image_size % 14 == 0, \
"chunk size should be multiple of 14"
token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
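As a worked example with the stock Llama-3.2-Vision configuration (vision_config.image_size == 560; the numbers below are only for illustration):

image_size = 560                              # hf_config.vision_config.image_size
token_per_chunk = (image_size // 14) ** 2 + 1 # 40 * 40 + 1 = 1601 tokens per tile
num_tiles = 4                                 # e.g. a large image split into 4 tiles
num_tokens = num_tiles * token_per_chunk      # 6404 encoder slots reserved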
@@ -675,6 +698,7 @@ class MllamaTextCrossAttention(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
kv_range_for_decode: Optional[List[Tuple[int, int]]],
cross_attention_states: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
@@ -697,6 +721,12 @@ class MllamaTextCrossAttention(nn.Module):
q = q.view(-1, self.num_local_heads, self.head_dim)
q = self.q_norm(q)
if attention_mask is not None:
output = self.attention_with_mask(q, k, v, kv_cache,
attention_mask,
kv_range_for_decode,
attn_metadata)
else:
output = self.attn(q,
k,
v,
@@ -706,6 +736,56 @@ class MllamaTextCrossAttention(nn.Module):
out, _ = self.o_proj(output)
return out
def attention_with_mask(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_cache: torch.Tensor,
attention_mask: torch.Tensor,
kv_range_for_decode: List[Tuple[int, int]],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Skip writing kv-cache for the initial profiling run.
if len(kv_cache.shape) == 3:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_local_key_value_heads, self.head_dim)
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
PagedAttention.write_to_paged_cache(
cached_k, cached_v, key_cache, value_cache,
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
# We have to call torch's scaled_dot_product_attention for prefill
# when a custom cross-attention mask is used, because the mask is
# neither a standard causal mask nor a block-diagonal mask that
# xformers.BlockDiagonalMask could optimize.
# The mask is computed specifically to support multiple and
# interleaved images.
q_len = q.shape[0]
kv_len = k.shape[0]
q = q.transpose(0, 1).view(self.num_local_key_value_heads,
self.num_key_value_groups, q_len,
self.head_dim)
k = k.transpose(0,
1)[:,
None, :, :].expand(self.num_local_key_value_heads,
self.num_key_value_groups,
kv_len, self.head_dim)
v = v.transpose(0,
1)[:,
None, :, :].expand(self.num_local_key_value_heads,
self.num_key_value_groups,
kv_len, self.head_dim)
attention_mask = attention_mask.view(1, 1, q_len, kv_len)
output = F.scaled_dot_product_attention(q,
k,
v,
attn_mask=attention_mask,
is_causal=False)
output = output.permute(2, 0, 1, 3).reshape(
q_len, self.num_local_heads * self.head_dim)
return output
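A toy shape sketch of the SDPA call above (all sizes are made up): the query keeps its heads grouped as (kv_heads, groups), keys/values are broadcast across the query groups, and the mask is additive, with a large negative value wherever a text token must not see a tile's tokens.

import torch
import torch.nn.functional as F

num_kv_heads, num_groups, head_dim = 2, 4, 8      # num_heads == 2 * 4
q_len, kv_len = 5, 12                             # toy prompt / encoder lengths
q = torch.randn(num_kv_heads, num_groups, q_len, head_dim)
k = torch.randn(num_kv_heads, 1, kv_len, head_dim).expand(-1, num_groups, -1, -1)
v = torch.randn(num_kv_heads, 1, kv_len, head_dim).expand(-1, num_groups, -1, -1)
mask = torch.zeros(1, 1, q_len, kv_len)
mask[..., 6:] = torch.finfo(mask.dtype).min       # toy: only the first 6 kv tokens visible
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)
assert out.shape == (num_kv_heads, num_groups, q_len, head_dim)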
class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
"""Cross-attention transformer block with tanh-gated attention
@@ -741,6 +821,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
hidden_states: torch.Tensor,
cross_attention_states: torch.Tensor,
cross_attention_mask: torch.Tensor,
kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: torch.Tensor,
kv_cache: List[torch.Tensor],
attn_metadata: AttentionMetadata,
@@ -751,6 +832,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
hidden_states = self.cross_attn(
hidden_states=hidden_states,
attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
cross_attention_states=cross_attention_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
@@ -804,6 +886,7 @@ class MllamaTextModel(nn.Module):
positions: Optional[torch.LongTensor],
cross_attention_states: Optional[torch.LongTensor],
cross_attention_mask: Optional[torch.LongTensor],
kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
torch.Tensor]],
kv_caches: List[torch.Tensor],
@@ -820,6 +903,7 @@ class MllamaTextModel(nn.Module):
hidden_states=hidden_states,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask=
full_text_row_masked_out_mask,
kv_cache=kv_caches[idx],
@@ -868,6 +952,7 @@ class MllamaForCausalLM(nn.Module):
positions: Optional[torch.LongTensor],
cross_attention_states: Optional[torch.LongTensor],
cross_attention_mask: Optional[torch.LongTensor],
kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
torch.Tensor]],
kv_caches: List[torch.Tensor],
@@ -879,6 +964,7 @@ class MllamaForCausalLM(nn.Module):
positions=positions,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
@@ -1026,57 +1112,30 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
raise AssertionError("This line should be unreachable.")
def flat_encoder_result(self, cross_attention_states: torch.Tensor,
attn_metadata: AttentionMetadata):
attn_metadata: AttentionMetadata,
actual_encoder_seq_lens: List[int]):
cross_attention_states_flat = torch.zeros(
sum(attn_metadata.encoder_seq_lens),
sum(actual_encoder_seq_lens),
cross_attention_states.shape[-1],
device=cross_attention_states.device,
dtype=cross_attention_states.dtype)
start_pos = 0
for seq_len, vision_token_in_batch in zip(
attn_metadata.encoder_seq_lens, cross_attention_states):
for seq_len, vision_token_in_batch in zip(actual_encoder_seq_lens,
cross_attention_states):
end_pos = start_pos + seq_len
cross_attention_states_flat[
start_pos:end_pos] = vision_token_in_batch[:seq_len]
start_pos = end_pos
cross_attention_states = cross_attention_states_flat
return cross_attention_states
full_text_row_masked_out_mask = torch.ones(
(attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
start_pos = 0
for seq_len, encoder_seq_len in zip(
attn_metadata.seq_lens_tensor.cpu(),
attn_metadata.encoder_seq_lens):
if encoder_seq_len == 0:
full_text_row_masked_out_mask[start_pos:start_pos +
seq_len] = False
start_pos += seq_len
full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
cross_attention_states.device)
return cross_attention_states, full_text_row_masked_out_mask
def forward(
def get_cross_attention_states(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
image_inputs: MllamaImagePixelInputs,
attn_metadata: AttentionMetadata,
**kwargs: object,
) -> Union[Tuple, CausalLMOutputWithPast]:
if attn_metadata.num_prefill_tokens > 0 and \
attn_metadata.num_decode_tokens > 0:
raise ValueError("Chunk prefill not supported")
image_inputs = self._parse_and_validate_image_input(**kwargs)
if image_inputs is None:
cross_attention_mask = None
full_text_row_masked_out_mask = (
attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to(
input_ids.device)
cross_attention_states = None
skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0
else:
actual_encoder_seq_lens: List[int],
) -> Tuple[torch.Tensor]:
# NOTE: llama's reference implementation runs vision model on CPU
pixel_values = image_inputs['data']
aspect_ratio_ids = image_inputs['aspect_ratio_ids']
@@ -1091,17 +1150,125 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
cross_attention_states = cross_attention_states.view(
bsz, -1, image_token_dim)
cross_attention_states, full_text_row_masked_out_mask = \
self.flat_encoder_result(cross_attention_states, attn_metadata)
skip_cross_attention = False
# TODO: support multi-image by this mask
cross_attention_states = self.flat_encoder_result(
cross_attention_states, attn_metadata, actual_encoder_seq_lens)
return cross_attention_states
def get_cross_attention_mask(
self,
input_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
num_tiles: List[List[int]],
num_tokens_per_tile: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
token_ids = input_ids.tolist()
start = 0
batch_token_ids = []
for seq_len in attn_metadata.seq_lens:
batch_token_ids.append(token_ids[start:start + seq_len])
start += seq_len
sparse_mask = [
get_cross_attention_token_mask(t, MLLAMA_IMAGE_TOKEN_ID)
for t in batch_token_ids
]
# Skip generating cross-attention mask if all samples
# are text-only or have only 1 leading image.
if skip_attention_mask(sparse_mask):
return None, None
dense_mask, tile_range_for_decode = \
convert_sparse_cross_attention_mask_to_dense(
sparse_mask, num_tiles, attn_metadata.seq_lens)
cross_attention_mask = \
convert_dense_cross_attention_mask_to_tensor(
dense_mask, num_tokens_per_tile, input_ids.device, dtype)
kv_range_for_decode = [[
t[0] * num_tokens_per_tile, t[1] * num_tokens_per_tile
] for t in tile_range_for_decode]
return cross_attention_mask, kv_range_for_decode
def get_full_text_row_masked_out_mask(
self,
attn_metadata: AttentionMetadata,
device: torch.device,
) -> torch.Tensor:
full_text_row_masked_out_mask = torch.ones(
(attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
start_pos = 0
for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens,
attn_metadata.encoder_seq_lens):
if encoder_seq_len == 0:
full_text_row_masked_out_mask[start_pos:start_pos +
seq_len] = False
start_pos += seq_len
full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
device)
return full_text_row_masked_out_mask
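A toy rendering of the column this produces (the lengths are assumed for illustration); rows belonging to text-only samples come out False, so their cross-attention contribution is zeroed downstream:

import torch
seq_lens, encoder_seq_lens = [3, 2], [1601, 0]    # second sample is text-only
mask = torch.ones((sum(seq_lens), 1), dtype=torch.bool)
start = 0
for seq_len, enc_len in zip(seq_lens, encoder_seq_lens):
    if enc_len == 0:
        mask[start:start + seq_len] = False       # rows of the text-only sample
    start += seq_len
# mask.squeeze(1) -> tensor([True, True, True, False, False])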
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs: object,
) -> Union[Tuple, CausalLMOutputWithPast]:
if attn_metadata.num_prefill_tokens > 0 and \
attn_metadata.num_decode_tokens > 0:
raise ValueError("Chunk prefill not supported")
image_inputs = self._parse_and_validate_image_input(**kwargs)
cross_attention_states = None
cross_attention_mask = None
kv_range_for_decode = None
# For 1) text-only prefill and decode, 2) image-present decode.
if image_inputs is None:
full_text_row_masked_out_mask = (
attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to(
input_ids.device)
skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0
# For image-present prefill.
else:
skip_cross_attention = False
# Get the actual number of encoder tokens for each sample.
# attn_metadata.encoder_seq_lens only counts the last group of
# images for each sample; that smaller count is what tricks the
# block manager into allocating blocks for those images only.
# See input_processor_for_mllama() for more details.
num_tiles_tensor = kwargs.pop("num_tiles")
num_tiles = [t[0].tolist() for t in num_tiles_tensor]
num_tokens_per_tile = (self.image_size // 14)**2 + 1
actual_encoder_seq_lens = [
sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
]
for actual_len, last_group_len in zip(
actual_encoder_seq_lens, attn_metadata.encoder_seq_lens):
assert actual_len >= last_group_len
cross_attention_states = self.get_cross_attention_states(
image_inputs, attn_metadata, actual_encoder_seq_lens)
full_text_row_masked_out_mask = \
self.get_full_text_row_masked_out_mask(
attn_metadata, input_ids.device)
cross_attention_mask, kv_range_for_decode = \
self.get_cross_attention_mask(
input_ids, attn_metadata, num_tiles,
num_tokens_per_tile, cross_attention_states.dtype)
outputs = self.language_model(
input_ids=input_ids,
positions=positions,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
@@ -1140,3 +1307,76 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
def skip_attention_mask(sparse_mask: List[List[List[int]]]) -> bool:
for mask in sparse_mask:
# Skip text-only samples.
if len(mask) == 0:
continue
# If the sample contains more than one image,
# we can't skip the mask.
if len(mask) != 1:
return False
# If the sample contains exactly one image but it is
# not the leading one, we can't skip the mask either.
if mask[0][0] != 0 or mask[0][1] != -1:
return False
return True
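A few illustrative cases for the fast path above (the sparse masks are hand-written in the [start, end] format the function expects; indices are made up):

assert skip_attention_mask([[]])                       # text-only sample
assert skip_attention_mask([[[0, -1]]])                # single leading image
assert not skip_attention_mask([[[0, -1], [7, -1]]])   # more than one image
assert not skip_attention_mask([[[5, -1]]])            # single image, not leading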
def convert_sparse_cross_attention_mask_to_dense(
sparse_mask: List[List[List[int]]],
num_tiles: List[List[int]],
lengths: List[int],
) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
total_length = sum(lengths)
total_tiles = sum([sum(tiles) for tiles in num_tiles])
dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
# A list of ranges; range[i] = (start, end) means that, out of all
# tiles in the batch, tiles[start:end] are the ones the i-th sample
# attends to during cross-attention decoding.
tile_range_for_decode = []
seq_start = 0
tile_start = 0
for masks, tiles, length in zip(sparse_mask, num_tiles, lengths):
ts, td = -1, 0
for mask, tile in zip(masks, tiles):
if len(mask) != 2:
continue
start, end = mask
end = min(end, length)
if end == -1:
end = length
if end == length:
if ts == -1:
ts = tile_start
td += tile
dense_mask[seq_start + start:seq_start + end,
tile_start:tile_start + tile] = 1
tile_start += tile
tile_range_for_decode.append((ts, ts + td))
seq_start += length
return dense_mask, tile_range_for_decode
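A worked toy example (the sparse mask is hand-written for illustration, not produced by get_cross_attention_token_mask): one sample of length 5 with two images of 2 tiles each, both attended until the end of the prompt.

sparse_mask = [[[0, -1], [2, -1]]]   # per-sample [start, end] ranges per image
num_tiles = [[2, 2]]
lengths = [5]
dense, tile_range = convert_sparse_cross_attention_mask_to_dense(
    sparse_mask, num_tiles, lengths)
# dense.shape == (5, 4): rows 0..4 see tiles 0..1, rows 2..4 also see tiles 2..3
# tile_range == [(0, 4)]: at decode time the sample attends to all 4 tiles,
# which get_cross_attention_mask() scales by num_tokens_per_tile to build
# kv_range_for_decode.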
def convert_dense_cross_attention_mask_to_tensor(
cross_attention_token_mask: np.ndarray,
num_tokens_per_tile: int,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
mask = torch.tensor(cross_attention_token_mask, dtype=dtype, device=device)
mask = mask.repeat_interleave(num_tokens_per_tile, dim=1)
mask = 1.0 - mask
mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(dtype).min)
ninf = torch.finfo(dtype).min
full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
mask *= full_text_mask
# (num_prompt_tokens, num_encoder_tokens)
return mask
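Continuing the toy example above (torch is already imported in this module; 3 tokens per tile is a made-up value, the real count is (image_size // 14)**2 + 1):

mask = convert_dense_cross_attention_mask_to_tensor(
    dense, 3, device=torch.device("cpu"), dtype=torch.float32)
# mask.shape == (5, 12): allowed positions are 0.0, blocked positions are
# torch.finfo(dtype).min. A row with no visible tiles at all would be rescaled
# to all zeros, avoiding a softmax over nothing but -inf; such rows are later
# nulled by full_text_row_masked_out_mask anyway.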