[Bugfix] handle alignment of encoder_seq_lens in mllama.py (#14784)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
parent 5285589f37
commit 71b9cde010
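Why the alignment matters, sketched with the values from the last case of the new unit test (an illustrative snippet, not the vLLM code): in a mixed batch, attn_metadata.encoder_seq_lens carries 0-length entries for text-only requests and only counts the last group of images per sample, so the old element-wise zip against lengths recomputed from num_tiles could pair the wrong samples and trip the assertion.

# Hypothetical repro values, taken from the new test's last parametrized case.
encoder_seq_lens = [0, 19212, 0, 3202]  # 0s come from text-only requests
num_tiles = [[4, 4, 4], [2]]            # tile counts per image, per sample
num_tokens_per_tile = 1601

# Per-sample lengths recomputed from the tile counts (what the fix uses).
actual_encoder_seq_lens = [
    sum(tiles) * num_tokens_per_tile for tiles in num_tiles
]  # -> [19212, 3202]

# The old code zipped the two lists directly; with the 0 entries present the
# pairs become (19212, 0) and (3202, 19212), and the second pair violates the
# "actual >= last group" assertion.
print(list(zip(actual_encoder_seq_lens, encoder_seq_lens)))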
@@ -209,14 +209,15 @@ def _run_test(
     # will hurt multiprocessing backend with fork method (the default method).
 
     # max_model_len should be greater than image_feature_size
-    with vllm_runner(model,
-                     dtype=dtype,
-                     max_model_len=8192,
-                     max_num_seqs=3,
-                     tensor_parallel_size=tensor_parallel_size,
-                     distributed_executor_backend=distributed_executor_backend,
-                     limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
-                                          }) as vllm_model:
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            max_model_len=19212,  # 3 max size images
+            max_num_seqs=3,
+            tensor_parallel_size=tensor_parallel_size,
+            distributed_executor_backend=distributed_executor_backend,
+            limit_mm_per_prompt={"image":
+                                 _LIMIT_IMAGE_PER_PROMPT}) as vllm_model:
         vllm_outputs_per_image = [
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,
@@ -507,7 +508,7 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
             model,
             dtype=dtype,
             max_model_len=8192,
-            max_num_seqs=2,
+            max_num_seqs=4,
             tensor_parallel_size=1,
             limit_mm_per_prompt={"image":
                                  _LIMIT_IMAGE_PER_PROMPT}) as vllm_model:
@@ -552,6 +553,23 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
                                             num_logprobs,
                                             images=images)
+
+        # Mixed batch with text and images with different numbers of tiles
+        prompts = [
+            "<|begin_of_text|>Hello!",
+            "<|begin_of_text|>Some text before.<|image|>What is in the image?",  # noqa: E501
+            "<|begin_of_text|>Some text before.<|image|>What is in the image?",  # noqa: E501
+        ]
+        images = [
+            None,
+            [stop_sign],
+            # smaller image must be 2nd for the repro
+            [stop_sign.resize((448, 448))],
+        ]
+        vllm_model.generate_greedy_logprobs(prompts,
+                                            max_tokens,
+                                            num_logprobs,
+                                            images=images)
 
 
 class DummyModel:
     image_token_id = MLLAMA_IMAGE_TOKEN_ID
@@ -674,3 +692,26 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
             f"full_text_row_masked_out_mask[{idx}] must be " \
             f"'{must_be_masked}' "
         idx += 1
+
+
+@pytest.mark.core_model
+@pytest.mark.parametrize("encoder_seq_lens, num_tiles, expected", [
+    ([6404], [[4]], [6404]),
+    ([0, 6404], [[4]], [6404]),
+    ([0, 1601, 8005], [[1], [4, 1]], [1601, 8005]),
+    ([0, 19212, 0, 3202], [[4, 4, 4], [2]], [19212, 3202]),
+])
+def test_parse_and_validate_encoder_lens(encoder_seq_lens, num_tiles,
+                                         expected) -> None:
+
+    dummy = DummyModel()
+    num_tokens_per_tile = 1601
+    actual_encoder_seq_lens = MllamaForConditionalGeneration \
+        ._get_and_validate_encoder_lens(
+            dummy,
+            encoder_seq_lens,
+            num_tiles,
+            num_tokens_per_tile,
+        )
+    assert actual_encoder_seq_lens == expected, \
+        f"Expected {expected} but got {actual_encoder_seq_lens}"
@@ -1301,6 +1301,31 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         raise AssertionError("This line should be unreachable.")
 
+    def _get_and_validate_encoder_lens(
+        self,
+        encoder_seq_lens: List[int],
+        num_tiles: List[List[int]],
+        num_tokens_per_tile: int,
+    ) -> List[int]:
+        # Get the actual number of encoder tokens for each sample.
+        # Because attn_metadata.encoder_seq_lens only counts the last
+        # group of images for each sample, which is used to cheat the
+        # block manager to allocate blocks for those images only.
+        # See MllamaMultiModalProcessor for more details.
+        actual_encoder_seq_lens = [
+            sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
+        ]
+
+        # remove 0 encoder len entries for text-only requests for these
+        # assertions
+        attn_metadata_lens = [x for x in encoder_seq_lens if x > 0]
+        assert len(actual_encoder_seq_lens) == len(attn_metadata_lens)
+        for actual_len, last_group_len in zip(actual_encoder_seq_lens,
+                                              attn_metadata_lens):
+            assert actual_len >= last_group_len
+
+        return actual_encoder_seq_lens
+
     def flat_encoder_result(self, cross_attention_states: torch.Tensor,
                             attn_metadata: AttentionMetadata,
                             actual_encoder_seq_lens: List[int]):
@@ -1428,20 +1453,14 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
         else:
             skip_cross_attention = False
 
-            # Get the actual number of encoder tokens for each sample.
-            # Because attn_metadata.encoder_seq_lens only counts the last
-            # group of images for each sample, which is used to cheat the
-            # block manager to allocate blocks for those images only.
-            # See MllamaMultiModalProcessor for more details.
-            num_tiles_tensor = kwargs.pop("num_tiles")
-            num_tiles = [t.tolist() for t in num_tiles_tensor]
+            num_tiles = [t.tolist() for t in kwargs.pop("num_tiles")]
             num_tokens_per_tile = calc_token_per_chunk(self.image_size)
-            actual_encoder_seq_lens = [
-                sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
-            ]
-            for actual_len, last_group_len in zip(
-                    actual_encoder_seq_lens, attn_metadata.encoder_seq_lens):
-                assert actual_len >= last_group_len
+
+            actual_encoder_seq_lens = self._get_and_validate_encoder_lens(
+                attn_metadata.encoder_seq_lens,
+                num_tiles,
+                num_tokens_per_tile,
+            )
 
             cross_attention_states = self.get_cross_attention_states(
                 image_inputs, attn_metadata, actual_encoder_seq_lens)
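For reference, the logic of the new helper in isolation (a minimal standalone sketch, not the vLLM implementation; the 1601 tokens-per-tile figure and the check values come from the new test):

from typing import List


def get_and_validate_encoder_lens(
    encoder_seq_lens: List[int],
    num_tiles: List[List[int]],
    num_tokens_per_tile: int,
) -> List[int]:
    # Recompute each sample's true encoder length from its tile counts,
    # since attn_metadata.encoder_seq_lens only reflects the last group
    # of images for a sample.
    actual = [sum(tiles) * num_tokens_per_tile for tiles in num_tiles]
    # Drop the 0-length entries contributed by text-only requests before
    # comparing per sample.
    nonzero = [x for x in encoder_seq_lens if x > 0]
    assert len(actual) == len(nonzero)
    for actual_len, last_group_len in zip(actual, nonzero):
        assert actual_len >= last_group_len
    return actual


# Same cases as the new test's parametrization.
assert get_and_validate_encoder_lens([0, 6404], [[4]], 1601) == [6404]
assert get_and_validate_encoder_lens([0, 19212, 0, 3202], [[4, 4, 4], [2]],
                                     1601) == [19212, 3202]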