diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index 98a73916..0493222d 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -229,8 +229,8 @@ def load_mllama(question: str, image_urls: list[str]) -> ModelRequestData:
         limit_mm_per_prompt={"image": len(image_urls)},
     )
 
-    placeholders = "<|image|>" * len(image_urls)
-    prompt = f"{placeholders}<|begin_of_text|>{question}"
+    img_prompt = "Given the first image <|image|> and the second image<|image|>"
+    prompt = f"<|begin_of_text|>{img_prompt}, {question}?"
     return ModelRequestData(
         engine_args=engine_args,
         prompt=prompt,
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index ac4bdbc4..68d5298d 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -180,10 +180,10 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
         mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                                   return_mm_hashes)
 
+        image_token_id = self.info.get_hf_config().image_token_index
         # Check that the number of image tokens in the decoder prompt matches
         # the number of images provided in mm_data
-        num_image_tokens = mm_inputs['prompt_token_ids'].count(
-            self.info.get_hf_config().image_token_index)
+        num_image_tokens = mm_inputs['prompt_token_ids'].count(image_token_id)
         image_data = mm_data.get("image", [])
         num_images = 1 if isinstance(image_data, Image) else len(image_data)
         if num_image_tokens != num_images:
@@ -191,8 +191,55 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
                 f"The number of image tokens ({num_image_tokens}) must be"
                 f" the same as the number of images ({num_images})")
 
+        # Given prompt: <img1> P0 P1 <img2> <img3> P3 P4 D5 D6...., (P-prefill, D-decode)  # noqa: E501
+        # P0 & P1 do cross attention with placeholder of <img1>
+        # P3 P4 D5 D6 do cross attention with placeholder of <img2> and <img3>
+        # Example input to encoder and decoder:
+        # {
+        #     'encoder': {
+        #         'type': 'token',
+        #         'prompt_token_ids': [128256, 128256, ..., 128256],
+        #         'prompt': '<|image|><|image|>...<|image|>',
+        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=...>},  # noqa: E501
+        #     },
+        #     'decoder': {
+        #         'type': 'token',
+        #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
+        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
+        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=...>},  # noqa: E501
+        #     },
+        # }
+
+        if mm_data:
+            # Since only the last group of consecutive images
+            # is attended by the decoded tokens, we only need to
+            # get the number of tokens for those images.
+            token_per_chunk = self.info.get_token_per_chunk_from_config()
+            num_decode_images = self._get_num_image_in_last_group(
+                mm_inputs["prompt_token_ids"])
+            num_encode_images = num_images - num_decode_images
+
+            # Set encoder prompt length based on the number of tiles.
+            # This tells the block manager to allocate the correct number
+            # of slots for encoder tokens.
+            num_tiles = mm_inputs["mm_kwargs"]["num_tiles"]
+            decode_tiles = num_tiles[num_encode_images:num_images].sum().item()
+            num_tokens = decode_tiles * token_per_chunk
+            mm_inputs["encoder_prompt_token_ids"] = [image_token_id
+                                                     ] * num_tokens
+            mm_inputs["encoder_prompt"] = "<|image|>" * num_tokens
+
         return mm_inputs
 
+    def _get_num_image_in_last_group(self, prompt_token_ids: List[int]) -> int:
+        num_images = 0
+        for token_id in prompt_token_ids[::-1]:
+            if token_id == self.info.get_hf_config().image_token_index:
+                num_images += 1
+            elif num_images > 0:
+                break
+        return num_images
+
     def _call_hf_processor(
         self,
         prompt: str,
@@ -210,19 +257,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
         processed_outputs["num_tiles"] = torch.tensor(num_tiles)
         for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"):
             processed_outputs[k] = processed_outputs[k].squeeze(0)
-        # Example input to encoder and decoder:
-        # {
-        #     'encoder': {
-        #         'type': 'token',
-        #         'prompt_token_ids': [128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
-        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
-        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=...>},  # noqa: E501
-        #     },
-        #     'decoder': {
-        #         'type': 'token',
-        #         'prompt_token_ids': [128000],
-        #     },
-        # }
+
         processed_token_ids = processed_outputs.pop("input_ids")
         start_idx, end_idx = 0, processed_token_ids.size(1)
         processed_prompt_text = tokenizer.decode(processed_token_ids[0])
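
For reference, below is a minimal standalone sketch (not part of the diff) of the "last consecutive group" rule that _get_num_image_in_last_group implements: walk the decoder prompt token ids from the end and count image tokens until the trailing run of images is broken. The image token id 128256 and the sample token lists are assumptions taken from the example comments above, not values read from a real tokenizer.

# Illustrative sketch only -- mirrors _get_num_image_in_last_group above.
from typing import List

IMAGE_TOKEN_ID = 128256  # assumed <|image|> token id, from the comment above


def num_images_in_last_group(prompt_token_ids: List[int]) -> int:
    """Count the trailing run of consecutive image tokens in the prompt."""
    num_images = 0
    for token_id in reversed(prompt_token_ids):
        if token_id == IMAGE_TOKEN_ID:
            num_images += 1
        elif num_images > 0:
            # The last group of images has ended.
            break
    return num_images


# '<|image|><|begin_of_text|>What is the content of this image?'
# -> a single image forms the last group, so the decoder attends to one image.
assert num_images_in_last_group(
    [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30]) == 1

# image, text, image, image, text -> the two consecutive images before the
# final text form the last group; the first image is encoder-only.
assert num_images_in_last_group([128256, 128000, 128256, 128256, 128000]) == 2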
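
The encoder prompt sizing in the `if mm_data:` block can be illustrated in isolation the same way. The tile counts and the tokens-per-tile value below are made-up example numbers, not values taken from the Mllama config; the point is only the slicing and multiplication that the diff performs.

# Illustrative sketch only -- mirrors the encoder prompt sizing in apply().
import torch

# Assumed example values: three images with 4, 6 and 6 tiles; the first image
# is encoder-only, the last two are attended by the decoder (see above).
num_tiles = torch.tensor([4, 6, 6])
num_images = 3
num_decode_images = 2                     # from the last-group count
num_encode_images = num_images - num_decode_images
token_per_chunk = 1601                    # placeholder for get_token_per_chunk_from_config()

# Only the tiles of the decoder-attended images contribute to the encoder
# prompt length that the block manager must allocate slots for.
decode_tiles = num_tiles[num_encode_images:num_images].sum().item()  # 12
num_tokens = decode_tiles * token_per_chunk                          # 19212
encoder_prompt_token_ids = [128256] * num_tokens  # assumed <|image|> token id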