[Bugfix] Fix Mllama interleaved images input support (#15564)
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
parent 2bc4be4e32
commit 3c0ff914ac
@@ -229,8 +229,8 @@ def load_mllama(question: str, image_urls: list[str]) -> ModelRequestData:
         limit_mm_per_prompt={"image": len(image_urls)},
     )
 
-    placeholders = "<|image|>" * len(image_urls)
-    prompt = f"{placeholders}<|begin_of_text|>{question}"
+    img_prompt = "Given the first image <|image|> and the second image<|image|>"
+    prompt = f"<|begin_of_text|>{img_prompt}, {question}?"
     return ModelRequestData(
         engine_args=engine_args,
         prompt=prompt,
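For context on the example change above: the prompt now interleaves <|image|> placeholders with text instead of stacking them all before <|begin_of_text|>. A minimal offline-inference sketch of the pattern this fix enables is below; the model name, image assets, and decoding settings are illustrative assumptions, not part of this commit.

# Sketch only: interleaved-image Mllama prompt via vLLM's offline API.
# Model name, image assets, and decoding settings are illustrative assumptions.
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

llm = LLM(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",
    max_model_len=8192,
    max_num_seqs=2,
    limit_mm_per_prompt={"image": 2},  # allow two images per prompt
)

# Placeholders now appear inside the text rather than all at the front.
prompt = ("<|begin_of_text|>Given the first image <|image|> and the second "
          "image<|image|>, what are the similarities between them?")

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {
            "image": [
                ImageAsset("stop_sign").pil_image,
                ImageAsset("cherry_blossom").pil_image,
            ],
        },
    },
    sampling_params=SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)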
@@ -180,10 +180,10 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
         mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                                   return_mm_hashes)
 
+        image_token_id = self.info.get_hf_config().image_token_index
         # Check that the number of image tokens in the decoder prompt matches
         # the number of images provided in mm_data
-        num_image_tokens = mm_inputs['prompt_token_ids'].count(
-            self.info.get_hf_config().image_token_index)
+        num_image_tokens = mm_inputs['prompt_token_ids'].count(image_token_id)
         image_data = mm_data.get("image", [])
         num_images = 1 if isinstance(image_data, Image) else len(image_data)
         if num_image_tokens != num_images:
@@ -191,8 +191,55 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
                 f"The number of image tokens ({num_image_tokens}) must be"
                 f" the same as the number of images ({num_images})")
 
+        # Given prompt: <IMG0> P0 P1 <IMG1> <IMG2> P3 P4 D5 D6...., (P-prefill, D-decode)  # noqa: E501
+        # P0 & P1 do cross attention with placeholder of <IMG0>
+        # P3 P4 D5 D6 do cross attention with placeholder of <IMG1> and <IMG2>
+        # Example input to encoder and decoder:
+        # {
+        #     'encoder': {
+        #         'type': 'token',
+        #         'prompt_token_ids': [128256, 128256, ..., 128256],
+        #         'prompt': '<|image|><|image|>...<|image|>',
+        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
+        #     },
+        #     'decoder': {
+        #         'type': 'token',
+        #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
+        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
+        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
+        #     },
+        # }
+
+        if mm_data:
+            # Since only the last group of consecutive images
+            # are attended by the decoded tokens, we only need to
+            # get the number of tokens for those images.
+            token_per_chunk = self.info.get_token_per_chunk_from_config()
+            num_decode_images = self._get_num_image_in_last_group(
+                mm_inputs["prompt_token_ids"])
+            num_encode_images = num_images - num_decode_images
+
+            # Set encoder prompt length based on the number of tiles.
+            # This tells the block manager to allocate correct number
+            # of slots for encoder tokens.
+            num_tiles = mm_inputs["mm_kwargs"]["num_tiles"]
+            decode_tiles = num_tiles[num_encode_images:num_images].sum().item()
+            num_tokens = decode_tiles * token_per_chunk
+            mm_inputs["encoder_prompt_token_ids"] = [image_token_id
+                                                     ] * num_tokens
+            mm_inputs["encoder_prompt"] = "<|image|>" * num_tokens
+
         return mm_inputs
 
+    def _get_num_image_in_last_group(self, prompt_token_ids: List[int]) -> int:
+        num_images = 0
+        for token_id in prompt_token_ids[::-1]:
+            if token_id == self.info.get_hf_config().image_token_index:
+                num_images += 1
+            elif num_images > 0:
+                break
+        return num_images
+
     def _call_hf_processor(
         self,
         prompt: str,
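The core of the new logic: decoded tokens only cross-attend to the last group of consecutive images in the decoder prompt, so only those images contribute to the encoder prompt length that the block manager must allocate. A standalone sketch of that counting and the length computation follows; the image token id, tile counts, and tokens-per-tile value are made-up example numbers, not taken from the Mllama config.

# Toy sketch of the "last consecutive image group" accounting used above.
# IMAGE_TOKEN_ID, TOKEN_PER_CHUNK and the tile counts are made-up values.
IMAGE_TOKEN_ID = 128256
TOKEN_PER_CHUNK = 1601  # assumed tokens contributed per image tile


def num_images_in_last_group(prompt_token_ids: list[int]) -> int:
    # Walk backwards; count image tokens until a non-image token shows up
    # after at least one image token has been seen.
    count = 0
    for token_id in reversed(prompt_token_ids):
        if token_id == IMAGE_TOKEN_ID:
            count += 1
        elif count > 0:
            break
    return count


# <IMG0> text text <IMG1> <IMG2> text -> decoded tokens attend to IMG1 and IMG2
prompt_token_ids = [IMAGE_TOKEN_ID, 101, 102, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 103]
num_tiles = [4, 4, 1]  # tiles per image, in prompt order
num_images = 3

num_decode_images = num_images_in_last_group(prompt_token_ids)  # 2
num_encode_images = num_images - num_decode_images              # 1
decode_tiles = sum(num_tiles[num_encode_images:num_images])     # 4 + 1 = 5
encoder_prompt_len = decode_tiles * TOKEN_PER_CHUNK             # 5 * 1601
print(num_decode_images, decode_tiles, encoder_prompt_len)

This is why _get_num_image_in_last_group scans the prompt from the end: with interleaved placeholders, only the trailing image group determines the encoder allocation.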
@@ -210,19 +257,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
         processed_outputs["num_tiles"] = torch.tensor(num_tiles)
         for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"):
             processed_outputs[k] = processed_outputs[k].squeeze(0)
-        # Example input to encoder and decoder:
-        # {
-        #     'encoder': {
-        #         'type': 'token',
-        #         'prompt_token_ids': [128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
-        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
-        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
-        #     },
-        #     'decoder': {
-        #         'type': 'token',
-        #         'prompt_token_ids': [128000],
-        #     },
-        # }
         processed_token_ids = processed_outputs.pop("input_ids")
         start_idx, end_idx = 0, processed_token_ids.size(1)
         processed_prompt_text = tokenizer.decode(processed_token_ids[0])