[Bugfix] Fix Mllama interleaved images input support (#15564)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Author: Isotr0py
Date: 2025-03-30 02:11:15 +08:00
parent 2bc4be4e32
commit 3c0ff914ac
2 changed files with 52 additions and 17 deletions


@@ -229,8 +229,8 @@ def load_mllama(question: str, image_urls: list[str]) -> ModelRequestData:
         limit_mm_per_prompt={"image": len(image_urls)},
     )
 
-    placeholders = "<|image|>" * len(image_urls)
-    prompt = f"{placeholders}<|begin_of_text|>{question}"
+    img_prompt = "Given the first image <|image|> and the second image<|image|>"
+    prompt = f"<|begin_of_text|>{img_prompt}, {question}?"
 
     return ModelRequestData(
         engine_args=engine_args,
         prompt=prompt,

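For context, the example change above switches load_mllama from front-loading all <|image|> placeholders to a prompt that interleaves text and images, which is exactly what this bugfix makes work. Below is a minimal end-to-end sketch of driving such an interleaved request through vLLM's offline API; it is not part of this commit, and the model name, image URLs, and question are illustrative placeholders:

# Hedged sketch, not part of this commit: exercises an interleaved
# Mllama prompt via vLLM's offline API. The model name, image URLs,
# and question below are placeholder assumptions.
from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image

question = "What are the differences between these two images"
img_prompt = "Given the first image <|image|> and the second image<|image|>"
prompt = f"<|begin_of_text|>{img_prompt}, {question}?"

image_urls = [
    "https://example.com/first.jpg",   # placeholder URL
    "https://example.com/second.jpg",  # placeholder URL
]

llm = LLM(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",
    max_model_len=8192,
    max_num_seqs=2,
    limit_mm_per_prompt={"image": len(image_urls)},
)

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": [fetch_image(u) for u in image_urls]},
    },
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)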

@@ -180,10 +180,10 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
         mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                                   return_mm_hashes)
+        image_token_id = self.info.get_hf_config().image_token_index
 
         # Check that the number of image tokens in the decoder prompt matches
         # the number of images provided in mm_data
-        num_image_tokens = mm_inputs['prompt_token_ids'].count(
-            self.info.get_hf_config().image_token_index)
+        num_image_tokens = mm_inputs['prompt_token_ids'].count(image_token_id)
         image_data = mm_data.get("image", [])
         num_images = 1 if isinstance(image_data, Image) else len(image_data)
         if num_image_tokens != num_images:
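The hunk above hoists image_token_id into a local so the same id can later size the encoder prompt; the existing check simply counts placeholder tokens in the decoder prompt. A toy illustration of that check, using the Llama 3.2 token ids quoted in this diff's own comments:

# Toy illustration of the consistency check above; 128256 is the
# <|image|> token id and 128000 the <|begin_of_text|> id quoted in
# this diff's comments.
image_token_id = 128256
# <|begin_of_text|> <|image|> <|image|> "What is" ...
prompt_token_ids = [128000, 128256, 128256, 3923, 374]
num_image_tokens = prompt_token_ids.count(image_token_id)  # -> 2
num_images = 2  # e.g. len(mm_data["image"])
if num_image_tokens != num_images:
    raise ValueError(
        f"The number of image tokens ({num_image_tokens}) must be"
        f" the same as the number of images ({num_images})")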
@@ -191,8 +191,55 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
                 f"The number of image tokens ({num_image_tokens}) must be"
                 f" the same as the number of images ({num_images})")
 
+        # Given prompt: <IMG0> P0 P1 <IMG1> <IMG2> P3 P4 D5 D6...., (P-prefill, D-decode) # noqa: E501
+        # P0 & P1 do cross attention with placeholder of <IMG0>
+        # P3 P4 D5 D6 do cross attention with placeholder of <IMG1> and <IMG2>
+        # Example input to encoder and decoder:
+        # {
+        #     'encoder': {
+        #         'type': 'token',
+        #         'prompt_token_ids': [128256, 128256, ..., 128256],
+        #         'prompt': '<|image|><|image|>...<|image|>',
+        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
+        #     },
+        #     'decoder': {
+        #         'type': 'token',
+        #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
+        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501
+        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
+        #     },
+        # }
+        if mm_data:
+            # Since only the last group of consecutive images
+            # are attended by the decoded tokens, we only need to
+            # get the number of tokens for those images.
+            token_per_chunk = self.info.get_token_per_chunk_from_config()
+            num_decode_images = self._get_num_image_in_last_group(
+                mm_inputs["prompt_token_ids"])
+            num_encode_images = num_images - num_decode_images
+
+            # Set encoder prompt length based on the number of tiles.
+            # This tells the block manager to allocate correct number
+            # of slots for encoder tokens.
+            num_tiles = mm_inputs["mm_kwargs"]["num_tiles"]
+            decode_tiles = num_tiles[num_encode_images:num_images].sum().item()
+            num_tokens = decode_tiles * token_per_chunk
+            mm_inputs["encoder_prompt_token_ids"] = [image_token_id
+                                                     ] * num_tokens
+            mm_inputs["encoder_prompt"] = "<|image|>" * num_tokens
+
         return mm_inputs
 
+    def _get_num_image_in_last_group(self,
+                                     prompt_token_ids: List[int]) -> int:
+        num_images = 0
+        for token_id in prompt_token_ids[::-1]:
+            if token_id == self.info.get_hf_config().image_token_index:
+                num_images += 1
+            elif num_images > 0:
+                break
+        return num_images
+
     def _call_hf_processor(
         self,
         prompt: str,
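Since only the trailing run of <|image|> tokens is cross-attended by decode tokens, the encoder prompt is sized from the tiles of just those images. Below is a standalone, hedged replay of that grouping and sizing arithmetic; the token ids, tile counts, and per-chunk token count are illustrative assumptions, where the real code reads them from the tokenizer, the HF processor outputs, and the vision config:

# Hedged, standalone replay of the grouping + sizing logic above.
# IMAGE_TOKEN_ID, TOKEN_PER_CHUNK, the prompt, and the tile counts are
# illustrative assumptions, not values taken from this commit.
from typing import List

IMAGE_TOKEN_ID = 128256
TOKEN_PER_CHUNK = 1601  # per-tile encoder tokens (assumed; from vision config)

def get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
    # Walk backwards: skip trailing text, count the run of image tokens,
    # and stop at the first non-image token once the run has started.
    num_images = 0
    for token_id in reversed(prompt_token_ids):
        if token_id == IMAGE_TOKEN_ID:
            num_images += 1
        elif num_images > 0:
            break
    return num_images

# Decoder prompt: <img> text <img> <img> text -> last group holds 2 images.
prompt_token_ids = [128256, 271, 128256, 128256, 3923, 374]
num_images = 3
num_decode_images = get_num_image_in_last_group(prompt_token_ids)  # -> 2
num_encode_images = num_images - num_decode_images                 # -> 1

num_tiles = [4, 4, 2]  # tiles per image (illustrative HF processor output)
decode_tiles = sum(num_tiles[num_encode_images:num_images])  # 4 + 2 = 6
num_tokens = decode_tiles * TOKEN_PER_CHUNK  # encoder slots to allocate
encoder_prompt_token_ids = [IMAGE_TOKEN_ID] * num_tokens
assert num_tokens == 6 * TOKEN_PER_CHUNK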
@@ -210,19 +257,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
         processed_outputs["num_tiles"] = torch.tensor(num_tiles)
         for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"):
             processed_outputs[k] = processed_outputs[k].squeeze(0)
-        # Example input to encoder and decoder:
-        # {
-        #     'encoder': {
-        #         'type': 'token',
-        #         'prompt_token_ids': [128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
-        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501
-        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
-        #     },
-        #     'decoder': {
-        #         'type': 'token',
-        #         'prompt_token_ids': [128000],
-        #     },
-        # }
+
         processed_token_ids = processed_outputs.pop("input_ids")
         start_idx, end_idx = 0, processed_token_ids.size(1)
         processed_prompt_text = tokenizer.decode(processed_token_ids[0])