[Bugfix] handle alignment of encoder_seq_lens in mllama.py (#14784)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Authored by Travis Johnson on 2025-04-11 13:59:50 -06:00; committed by GitHub
parent 5285589f37
commit 71b9cde010
2 changed files with 82 additions and 22 deletions

@@ -209,14 +209,15 @@ def _run_test(
# will hurt multiprocessing backend with fork method (the default method).
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
max_model_len=8192,
max_num_seqs=3,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
}) as vllm_model:
with vllm_runner(
model,
dtype=dtype,
max_model_len=19212, # 3 max size images
max_num_seqs=3,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
limit_mm_per_prompt={"image":
_LIMIT_IMAGE_PER_PROMPT}) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
@@ -507,7 +508,7 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
model,
dtype=dtype,
max_model_len=8192,
max_num_seqs=2,
max_num_seqs=4,
tensor_parallel_size=1,
limit_mm_per_prompt={"image":
_LIMIT_IMAGE_PER_PROMPT}) as vllm_model:
@@ -552,6 +553,23 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
num_logprobs,
images=images)

# Mixed batch with text and images with different numbers of tiles
prompts = [
"<|begin_of_text|>Hello!",
"<|begin_of_text|>Some text before.<|image|>What is in the image?", # noqa: E501
"<|begin_of_text|>Some text before.<|image|>What is in the image?", # noqa: E501
]
images = [
None,
[stop_sign],
# smaller image must be 2nd for the repro
[stop_sign.resize((448, 448))],
]
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs,
images=images)

class DummyModel:
image_token_id = MLLAMA_IMAGE_TOKEN_ID
@@ -674,3 +692,26 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
f"full_text_row_masked_out_mask[{idx}] must be " \
f"'{must_be_masked}' "
idx += 1

@pytest.mark.core_model
@pytest.mark.parametrize("encoder_seq_lens, num_tiles, expected", [
([6404], [[4]], [6404]),
([0, 6404], [[4]], [6404]),
([0, 1601, 8005], [[1], [4, 1]], [1601, 8005]),
([0, 19212, 0, 3202], [[4, 4, 4], [2]], [19212, 3202]),
])
def test_parse_and_validate_encoder_lens(encoder_seq_lens, num_tiles,
expected) -> None:
dummy = DummyModel()
num_tokens_per_tile = 1601
actual_encoder_seq_lens = MllamaForConditionalGeneration \
._get_and_validate_encoder_lens(
dummy,
encoder_seq_lens,
num_tiles,
num_tokens_per_tile,
)
assert actual_encoder_seq_lens == expected, \
f"Expected {expected} but got {actual_encoder_seq_lens}"

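Note: the expected values in the parametrized test above follow directly from the per-tile token count, and the same arithmetic explains the max_model_len=19212 used earlier in the diff (3 max-size images, 4 tiles each, 1601 tokens per tile). A standalone sanity check, not part of the diff:

# Sanity check of the parametrized expectations (1601 tokens per tile):
assert 4 * 1601 == 6404                # ([6404], [[4]]) case
assert (4 + 1) * 1601 == 8005          # ([0, 1601, 8005], [[1], [4, 1]]) case
assert (4 + 4 + 4) * 1601 == 19212     # also the max_model_len for 3 full-size images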
@@ -1301,6 +1301,31 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
raise AssertionError("This line should be unreachable.")

def _get_and_validate_encoder_lens(
self,
encoder_seq_lens: List[int],
num_tiles: List[List[int]],
num_tokens_per_tile: int,
) -> List[int]:
# Get the actual number of encoder tokens for each sample.
# Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only.
# See MllamaMultiModalProcessor for more details.
actual_encoder_seq_lens = [
sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
]
# remove 0 encoder len entries for text-only requests for these
# assertions
attn_metadata_lens = [x for x in encoder_seq_lens if x > 0]
assert len(actual_encoder_seq_lens) == len(attn_metadata_lens)
for actual_len, last_group_len in zip(actual_encoder_seq_lens,
attn_metadata_lens):
assert actual_len >= last_group_len
return actual_encoder_seq_lens

def flat_encoder_result(self, cross_attention_states: torch.Tensor,
attn_metadata: AttentionMetadata,
actual_encoder_seq_lens: List[int]):
@@ -1428,20 +1453,14 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
else:
skip_cross_attention = False
# Get the actual number of encoder tokens for each sample.
# Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only.
# See MllamaMultiModalProcessor for more details.
num_tiles_tensor = kwargs.pop("num_tiles")
num_tiles = [t.tolist() for t in num_tiles_tensor]
num_tiles = [t.tolist() for t in kwargs.pop("num_tiles")]
num_tokens_per_tile = calc_token_per_chunk(self.image_size)
actual_encoder_seq_lens = [
sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
]
for actual_len, last_group_len in zip(
actual_encoder_seq_lens, attn_metadata.encoder_seq_lens):
assert actual_len >= last_group_len
actual_encoder_seq_lens = self._get_and_validate_encoder_lens(
attn_metadata.encoder_seq_lens,
num_tiles,
num_tokens_per_tile,
)
cross_attention_states = self.get_cross_attention_states(
image_inputs, attn_metadata, actual_encoder_seq_lens)