From 71b9cde01044fa6fa7c2fdf3043dd315a9f89e65 Mon Sep 17 00:00:00 2001
From: Travis Johnson
Date: Fri, 11 Apr 2025 13:59:50 -0600
Subject: [PATCH] [Bugfix] handle alignment of encoder_seq_lens in mllama.py
 (#14784)

Signed-off-by: Travis Johnson
---
 .../vision_language/test_mllama.py   | 59 ++++++++++++++++---
 vllm/model_executor/models/mllama.py | 45 ++++++++++----
 2 files changed, 82 insertions(+), 22 deletions(-)

diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py
index a9f0de76..d94c2e88 100644
--- a/tests/models/encoder_decoder/vision_language/test_mllama.py
+++ b/tests/models/encoder_decoder/vision_language/test_mllama.py
@@ -209,14 +209,15 @@ def _run_test(
     # will hurt multiprocessing backend with fork method (the default method).

     # max_model_len should be greater than image_feature_size
-    with vllm_runner(model,
-                     dtype=dtype,
-                     max_model_len=8192,
-                     max_num_seqs=3,
-                     tensor_parallel_size=tensor_parallel_size,
-                     distributed_executor_backend=distributed_executor_backend,
-                     limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
-                                          }) as vllm_model:
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            max_model_len=19212,  # 3 max size images
+            max_num_seqs=3,
+            tensor_parallel_size=tensor_parallel_size,
+            distributed_executor_backend=distributed_executor_backend,
+            limit_mm_per_prompt={"image":
+                                 _LIMIT_IMAGE_PER_PROMPT}) as vllm_model:
         vllm_outputs_per_image = [
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,
@@ -507,7 +508,7 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
             model,
             dtype=dtype,
             max_model_len=8192,
-            max_num_seqs=2,
+            max_num_seqs=4,
             tensor_parallel_size=1,
             limit_mm_per_prompt={"image":
                                  _LIMIT_IMAGE_PER_PROMPT}) as vllm_model:
@@ -552,6 +553,23 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
                                             num_logprobs,
                                             images=images)

+        # Mixed batch with text and images with different numbers of tiles
+        prompts = [
+            "<|begin_of_text|>Hello!",
+            "<|begin_of_text|>Some text before.<|image|>What is in the image?",  # noqa: E501
+            "<|begin_of_text|>Some text before.<|image|>What is in the image?",  # noqa: E501
+        ]
+        images = [
+            None,
+            [stop_sign],
+            # smaller image must be 2nd for the repro
+            [stop_sign.resize((448, 448))],
+        ]
+        vllm_model.generate_greedy_logprobs(prompts,
+                                            max_tokens,
+                                            num_logprobs,
+                                            images=images)
+

 class DummyModel:
     image_token_id = MLLAMA_IMAGE_TOKEN_ID
@@ -674,3 +692,26 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
                 f"full_text_row_masked_out_mask[{idx}] must be " \
                 f"'{must_be_masked}' "
             idx += 1
+
+
+@pytest.mark.core_model
+@pytest.mark.parametrize("encoder_seq_lens, num_tiles, expected", [
+    ([6404], [[4]], [6404]),
+    ([0, 6404], [[4]], [6404]),
+    ([0, 1601, 8005], [[1], [4, 1]], [1601, 8005]),
+    ([0, 19212, 0, 3202], [[4, 4, 4], [2]], [19212, 3202]),
+])
+def test_parse_and_validate_encoder_lens(encoder_seq_lens, num_tiles,
+                                         expected) -> None:
+
+    dummy = DummyModel()
+    num_tokens_per_tile = 1601
+    actual_encoder_seq_lens = MllamaForConditionalGeneration \
+        ._get_and_validate_encoder_lens(
+            dummy,
+            encoder_seq_lens,
+            num_tiles,
+            num_tokens_per_tile,
+        )
+    assert actual_encoder_seq_lens == expected, \
+        f"Expected {expected} but got {actual_encoder_seq_lens}"
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 566149c9..7bfb3ada 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -1301,6 +1301,31 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
         raise AssertionError("This line should be unreachable.")

+    def _get_and_validate_encoder_lens(
+        self,
+        encoder_seq_lens: List[int],
+        num_tiles: List[List[int]],
+        num_tokens_per_tile: int,
+    ) -> List[int]:
+        # Get the actual number of encoder tokens for each sample.
+        # Because attn_metadata.encoder_seq_lens only counts the last
+        # group of images for each sample, which is used to cheat the
+        # block manager to allocate blocks for those images only.
+        # See MllamaMultiModalProcessor for more details.
+        actual_encoder_seq_lens = [
+            sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
+        ]
+
+        # remove 0 encoder len entries for text-only requests for these
+        # assertions
+        attn_metadata_lens = [x for x in encoder_seq_lens if x > 0]
+        assert len(actual_encoder_seq_lens) == len(attn_metadata_lens)
+        for actual_len, last_group_len in zip(actual_encoder_seq_lens,
+                                              attn_metadata_lens):
+            assert actual_len >= last_group_len
+
+        return actual_encoder_seq_lens
+
     def flat_encoder_result(self, cross_attention_states: torch.Tensor,
                             attn_metadata: AttentionMetadata,
                             actual_encoder_seq_lens: List[int]):
@@ -1428,20 +1453,14 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
         else:
             skip_cross_attention = False

-        # Get the actual number of encoder tokens for each sample.
-        # Because attn_metadata.encoder_seq_lens only counts the last
-        # group of images for each sample, which is used to cheat the
-        # block manager to allocate blocks for those images only.
-        # See MllamaMultiModalProcessor for more details.
-        num_tiles_tensor = kwargs.pop("num_tiles")
-        num_tiles = [t.tolist() for t in num_tiles_tensor]
+        num_tiles = [t.tolist() for t in kwargs.pop("num_tiles")]
         num_tokens_per_tile = calc_token_per_chunk(self.image_size)
-        actual_encoder_seq_lens = [
-            sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
-        ]
-        for actual_len, last_group_len in zip(
-                actual_encoder_seq_lens, attn_metadata.encoder_seq_lens):
-            assert actual_len >= last_group_len
+
+        actual_encoder_seq_lens = self._get_and_validate_encoder_lens(
+            attn_metadata.encoder_seq_lens,
+            num_tiles,
+            num_tokens_per_tile,
+        )

         cross_attention_states = self.get_cross_attention_states(
             image_inputs, attn_metadata, actual_encoder_seq_lens)
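
For readers following the fix, below is a minimal, self-contained sketch (not the vLLM code itself) of the alignment logic that the new _get_and_validate_encoder_lens helper performs, reusing the parametrized cases from test_parse_and_validate_encoder_lens above. The standalone function name and the __main__ harness are illustrative only; the argument names and assertions mirror the patch.

# Standalone sketch of the encoder-length alignment added by this patch.
# attn_metadata.encoder_seq_lens only counts the last group of images per
# sample and contains 0 for text-only requests, so the real per-sample
# encoder length is recomputed from num_tiles and cross-checked against it.
from typing import List


def get_and_validate_encoder_lens(
    encoder_seq_lens: List[int],
    num_tiles: List[List[int]],
    num_tokens_per_tile: int,
) -> List[int]:
    # One entry per sample that actually carries images.
    actual = [sum(tiles) * num_tokens_per_tile for tiles in num_tiles]

    # Drop the 0 entries contributed by text-only requests before comparing.
    metadata_lens = [x for x in encoder_seq_lens if x > 0]
    assert len(actual) == len(metadata_lens)
    for actual_len, last_group_len in zip(actual, metadata_lens):
        # The full length must cover at least the last image group.
        assert actual_len >= last_group_len
    return actual


if __name__ == "__main__":
    # Cases mirrored from test_parse_and_validate_encoder_lens above
    # (1601 tokens per tile).
    cases = [
        ([6404], [[4]], [6404]),
        ([0, 6404], [[4]], [6404]),
        ([0, 1601, 8005], [[1], [4, 1]], [1601, 8005]),
        ([0, 19212, 0, 3202], [[4, 4, 4], [2]], [19212, 3202]),
    ]
    for encoder_seq_lens, num_tiles, expected in cases:
        assert get_and_validate_encoder_lens(encoder_seq_lens, num_tiles,
                                             1601) == expected
    print("all alignment checks passed")

The mixed text-and-image regression batch added in the test exercises exactly this path: attn_metadata.encoder_seq_lens carries a 0 for the text-only prompt, so the helper strips those entries before pairing the recomputed lengths with the metadata values.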