[V1] Support interleaved modality items (#15605)

Signed-off-by: Roger Wang <ywang@roblox.com>
Authored by Roger Wang on 2025-03-29 06:30:09 -07:00; committed by GitHub.
parent 6fa7cd3dbc
commit c67abd614f
6 changed files with 205 additions and 115 deletions

View File

@ -431,6 +431,7 @@ steps:
- pytest -v -s models/encoder_decoder/audio_language -m core_model
- pytest -v -s models/encoder_decoder/language -m core_model
- pytest -v -s models/encoder_decoder/vision_language -m core_model
+  - pytest -v -s models/decoder_only/vision_language/test_interleaved.py
- label: Multi-Modal Models Test (Extended) 1 # 48m
optional: true

View File

@ -747,30 +747,27 @@ class VllmRunner:
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> list[TextPrompt]:
-        if images is not None:
-            assert len(prompts) == len(images)
-        if videos is not None:
-            assert len(prompts) == len(videos)
-        if audios is not None:
-            assert len(prompts) == len(audios)
-        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
-        if images is not None:
-            for i, image in enumerate(images):
-                if image is not None:
-                    inputs[i]["multi_modal_data"] = {"image": image}
-        if videos is not None:
-            for i, video in enumerate(videos):
-                if video is not None:
-                    inputs[i]["multi_modal_data"] = {"video": video}
-        if audios is not None:
-            for i, audio in enumerate(audios):
-                if audio is not None:
-                    inputs[i]["multi_modal_data"] = {"audio": audio}
+        if any(x is not None and len(x) != len(prompts)
+               for x in [images, videos, audios]):
+            raise ValueError(
+                "All non-None multimodal inputs must have the same length as "
+                "prompts")
+        inputs = []
+        for i, prompt in enumerate(prompts):
+            multi_modal_data = {}
+            if images is not None and (image := images[i]) is not None:
+                multi_modal_data["image"] = image
+            if videos is not None and (video := videos[i]) is not None:
+                multi_modal_data["video"] = video
+            if audios is not None and (audio := audios[i]) is not None:
+                multi_modal_data["audio"] = audio
+            inputs.append(
+                TextPrompt(prompt=prompt,
+                           multi_modal_data=multi_modal_data
+                           if multi_modal_data else None))
return inputs
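
Note: before this change, get_inputs assigned multi_modal_data per modality in separate passes, so a later modality overwrote an earlier one and a single prompt could not carry, say, an image and a video at once. A minimal usage sketch, assuming a VllmRunner already opened as in the new test below (referred to here as `runner`, a hypothetical name) and the llava-onevision prompt format:

    from vllm.assets.image import ImageAsset
    from vllm.assets.video import VideoAsset

    # Illustration only, not part of this diff: one prompt that carries both
    # an image and a video in the same multi_modal_data dict.
    prompt = ("<|im_start|>user <image><video>\n"
              "Describe what you see from these items.<|im_end|>"
              "<|im_start|>assistant\n")
    image = ImageAsset("stop_sign").pil_image.convert("RGB")
    video = VideoAsset(name="sample_demo_1.mp4", num_frames=16).np_ndarrays

    # `runner` is an assumed, already-open VllmRunner; get_inputs now merges
    # the i-th image and i-th video into a single TextPrompt.
    outputs = runner.generate_greedy([prompt],
                                     64,
                                     images=[image],
                                     videos=[video])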

View File

@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
models = ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]
def base_prompt(modalities_str: str) -> str:
return f"<|im_start|>user {modalities_str}\nDescribe what you see from these items.<|im_end|><|im_start|>assistant\n" # noqa: E501
INTERLEAVED_PROMPT = base_prompt("<image><video><image>\n")
NONINTERLEAVED_PROMPT = base_prompt("<image><image><video>\n")
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["float16"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(vllm_runner, model, dtype: str, max_tokens: int) -> None:
"""
    This is a simple test to check that an interleaved prompt and its
    non-interleaved counterpart (same items, different order) give different
    results.
"""
image_cherry = ImageAsset("cherry_blossom").pil_image.convert("RGB")
image_stop = ImageAsset("stop_sign").pil_image.convert("RGB")
images = [image_cherry, image_stop]
video = VideoAsset(name="sample_demo_1.mp4", num_frames=16).np_ndarrays
inputs = [
(
[INTERLEAVED_PROMPT],
[images],
[video],
),
(
[NONINTERLEAVED_PROMPT],
[images],
[video],
),
]
with vllm_runner(model,
task="generate",
dtype=dtype,
limit_mm_per_prompt={"image": 2},
max_model_len=32768,
max_num_seqs=2,
tensor_parallel_size=1,
enforce_eager=True) as vllm_model:
vllm_outputs_per_case = [
vllm_model.generate_greedy(prompts,
max_tokens,
images=images,
videos=videos)
for prompts, images, videos in inputs
]
all_results = [output[0][1] for output in vllm_outputs_per_case]
outputs = [(total_str, total_str.find("assistant\n") + len("assistant\n"))
for total_str in all_results]
prompt_lengths = [prompt_len for _, prompt_len in outputs]
generated_strs = [
total_str[prompt_len:] for total_str, prompt_len in outputs
]
interleaved_prompt_len, noninterleaved_prompt_len = prompt_lengths
interleaved_output_str, noninterleaved_output_str = generated_strs
# The two prompts are identical except for the order of modality tokens.
assert interleaved_prompt_len == noninterleaved_prompt_len
# The two generated strings should be different because of the
# interleaved modality tokens.
assert interleaved_output_str != noninterleaved_output_str

View File

@ -155,7 +155,7 @@ def test_merge_and_sort_multimodal_metadata():
]
},
mm_hashes={"image": ["hash1", "hash2"]},
expected_modalities=["image"],
expected_modalities=["image", "image"],
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=3, length=2),
@ -172,7 +172,7 @@ def test_merge_and_sort_multimodal_metadata():
]
},
mm_hashes=None,
expected_modalities=["image"],
expected_modalities=["image", "image"],
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=2),
@ -197,7 +197,7 @@ def test_merge_and_sort_multimodal_metadata():
"image": ["image_hash1", "image_hash2"],
"audio": ["audio_hash1", "audio_hash2"],
},
expected_modalities=["audio", "image"],
expected_modalities=["audio", "audio", "image", "image"],
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=3),
@ -223,7 +223,7 @@ def test_merge_and_sort_multimodal_metadata():
]
},
mm_hashes=None,
expected_modalities=["audio", "image"],
expected_modalities=["audio", "audio", "image", "image"],
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=3),
@ -254,7 +254,9 @@ def test_merge_and_sort_multimodal_metadata():
"audio": ["audio_hash1"],
"video": ["video_hash1", "video_hash2", "video_hash3"]
},
expected_modalities=["audio", "video", "image"],
expected_modalities=[
"audio", "video", "video", "video", "image", "image"
],
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=3, length=4),
@ -300,12 +302,19 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
"image": ["image_hash1", "image_hash2"],
"audio": ["audio_hash1", "audio_hash2"],
},
-            expected_modalities=[],
-            expected_ranges=[],
-            expected_hashes=None,
+            expected_modalities=["image", "audio", "image", "audio"],
+            expected_ranges=[
+                PlaceholderRange(offset=0, length=4),
+                PlaceholderRange(offset=5, length=2),
+                PlaceholderRange(offset=8, length=2),
+                PlaceholderRange(offset=11, length=4),
+            ],
+            expected_hashes=[
+                "image_hash1", "audio_hash1", "image_hash2", "audio_hash2"
+            ],
),
-        # <image> <image> <video> <audio> <image>
+        # <image> <image> <audio> <video> <image>
TestCase(
mm_positions={
"image": [
@ -321,15 +330,54 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
]
},
mm_hashes=None,
-            expected_modalities=[],
-            expected_ranges=[],
+            expected_modalities=["image", "image", "audio", "video", "image"],
+            expected_ranges=[
+                PlaceholderRange(offset=0, length=2),
+                PlaceholderRange(offset=2, length=3),
+                PlaceholderRange(offset=5, length=2),
+                PlaceholderRange(offset=8, length=5),
+                PlaceholderRange(offset=20, length=4),
+            ],
expected_hashes=None,
),
+        # <image> <audio> <video> <image> with hashes
+        TestCase(
+            mm_positions={
+                "image": [
+                    PlaceholderRange(offset=0, length=2),
+                    PlaceholderRange(offset=18, length=4),
+                ],
+                "audio": [
+                    PlaceholderRange(offset=6, length=2),
+                ],
+                "video": [
+                    PlaceholderRange(offset=10, length=5),
+                ]
+            },
+            mm_hashes={
+                "image": ["image_hash1", "image_hash2"],
+                "audio": ["audio_hash1"],
+                "video": ["video_hash1"],
+            },
+            expected_modalities=["image", "audio", "video", "image"],
+            expected_ranges=[
+                PlaceholderRange(offset=0, length=2),
+                PlaceholderRange(offset=6, length=2),
+                PlaceholderRange(offset=10, length=5),
+                PlaceholderRange(offset=18, length=4),
+            ],
+            expected_hashes=[
+                "image_hash1", "audio_hash1", "video_hash1", "image_hash2"
+            ],
+        ),
]
-    for case in test_cases:
-        with pytest.raises(ValueError) as ex_info:
-            merge_and_sort_multimodal_metadata(case.mm_positions,
-                                               case.mm_hashes)
-        assert "Interleaved mixed-modality" in str(ex_info.value)
+    for (mm_positions, mm_hashes, expected_modalities, expected_ranges,
+         expected_hashes) in test_cases:
+        modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
+            mm_positions, mm_hashes)
+        assert modalities == expected_modalities
+        assert ranges == expected_ranges
+        assert hashes == expected_hashes

View File

@ -303,14 +303,10 @@ def merge_and_sort_multimodal_metadata(
Optionally if a MultiModalHashDict is given, same operation will be
applied to the object and the sorted list of hashes will be returned.
-    Raises:
-        ValueError: If the input prompt has interleaved placeholders from
-        different modalities (e.g, "<image><audio><image> Describe the
-        content.")
    Returns:
-        list[str]: Sorted list of involved modalities.
+        list[str]: List of item modalities in order of their positions in
+        the input sequence.
list[PlaceholderRange]: Sorted list of all PlaceholdeRanges from
mm_positions.
Optional[list[str]]: Sorted list of all hashes from mm_hashes if
@ -324,50 +320,33 @@ def merge_and_sort_multimodal_metadata(
# For single modality, placeholder ranges and hashes are already sorted
# so we can return the list directly.
if len(modalities) == 1:
-        if mm_hashes is None:
-            return modalities, list(mm_positions[modalities[0]]), None
-        else:
-            return modalities, list(mm_positions[modalities[0]]), list(
-                mm_hashes[modalities[0]])
+        modality = modalities[0]
+        placeholder_list = list(mm_positions[modality])
+        return [modality] * len(
+            placeholder_list
+        ), placeholder_list, None if not mm_hashes else mm_hashes[modality]

-    placeholder_lists_with_modality = [(modality, mm_positions[modality])
-                                       for modality in modalities]
-    if mm_hashes is None:
-        sorted_placeholder_lists = sorted(placeholder_lists_with_modality,
-                                          key=lambda x: x[1][0]['offset'])
-        sorted_hash_lists = None
-    else:
-        hashes_lists = [
-            mm_hashes[modality] for modality in modalities
-            if modality in mm_hashes
-        ]
-        sorted_pairs = sorted(zip(placeholder_lists_with_modality,
-                                  hashes_lists),
-                              key=lambda x: x[0][1][0]['offset'])
-        sorted_placeholder_tuple, sorted_hash_tuple = zip(*sorted_pairs)
-        sorted_placeholder_lists = list(sorted_placeholder_tuple)
-        sorted_hash_lists = list(sorted_hash_tuple)
-    sorted_modalities = [modality for modality, _ in sorted_placeholder_lists]
-    # Flatten sorted list of lists to a single list and verify there is no
-    # interleaving of placeholders from different modalities.
-    merged_placeholders: list[PlaceholderRange] = []
-    for modality, placeholder_list in sorted_placeholder_lists:
-        if merged_placeholders and placeholder_list[0][
-                'offset'] < merged_placeholders[-1]['offset']:
-            raise ValueError(
-                "Interleaved mixed-modality inference is currently not "
-                "supported.")
-        merged_placeholders.extend(placeholder_list)
-    if sorted_hash_lists is not None:
-        merged_hashes = []
-        for hash_list in sorted_hash_lists:
-            merged_hashes.extend(hash_list)
-    else:
-        merged_hashes = None
+    # Create a list of (modality, placeholder, hash) tuples for all placeholders
+    all_items = []
+    for modality in modalities:
+        placeholder_list = list(mm_positions[modality])
+        hash_list: list[Optional[str]] = list(
+            mm_hashes[modality]) if mm_hashes and modality in mm_hashes else [
+                None
+            ] * len(placeholder_list)
+        for placeholder, hash_value in zip(placeholder_list, hash_list):
+            all_items.append((modality, placeholder, hash_value))
+    # Sort all items by offset
+    all_items.sort(key=lambda x: x[1]['offset'])
+    # Split into separate lists
+    sorted_modalities = [item[0] for item in all_items]
+    merged_placeholders = [item[1] for item in all_items]
+    merged_hashes = [str(item[2])
+                     for item in all_items] if mm_hashes is not None else None
return sorted_modalities, merged_placeholders, merged_hashes
@ -383,8 +362,7 @@ def group_mm_inputs_by_modality(
Returns:
list[list[MultiModalKwargs]]: List of list of MultiModalKwargs, each
-        inner list contains consecutive MultiModalKwargs with same modality, or
-        one with multimodal modalities.
+        inner list contains consecutive MultiModalKwargs with same modality.
"""
if not mm_inputs:
return []
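
For reference, a minimal sketch of how the rewritten helper behaves on an interleaved prompt; the import paths below are assumptions (they are not shown in this diff), and the expected values mirror the new test cases above:

    from vllm.multimodal.inputs import PlaceholderRange  # assumed import path
    from vllm.multimodal.utils import merge_and_sort_multimodal_metadata  # assumed

    # <image> <audio> <image>: this layout used to raise
    # "Interleaved mixed-modality inference is currently not supported."
    mm_positions = {
        "image": [
            PlaceholderRange(offset=0, length=2),
            PlaceholderRange(offset=8, length=2),
        ],
        "audio": [PlaceholderRange(offset=4, length=3)],
    }
    mm_hashes = {
        "image": ["image_hash1", "image_hash2"],
        "audio": ["audio_hash1"],
    }

    modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
        mm_positions, mm_hashes)

    # One entry per item, ordered by offset:
    # modalities == ["image", "audio", "image"]
    # ranges     == placeholders at offsets 0, 4, 8
    # hashes     == ["image_hash1", "audio_hash1", "image_hash2"]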

View File

@ -234,22 +234,11 @@ class Processor:
if decoder_inputs["type"] == "multimodal":
decoder_mm_inputs = decoder_inputs["mm_kwargs"]
-            # The output of merged multi-modal processor (`decoder_mm_inputs`)
-            # contains the kwargs for all items from all modalities.
-            # This code separates them so that there is one set of kwargs
-            # per item per modality.
-            individual_mm_inputs = [
-                MultiModalKwargs.from_items([item])
-                for modality in decoder_mm_inputs.modalities
-                for item in decoder_mm_inputs.get_items(modality)
-            ]
            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's position
            # in the input sequence.
-            # NOTE: interleaved modalities are not supported.
            (
-                sorted_modalities,
+                sorted_item_modalities,
sorted_mm_positions,
sorted_mm_hashes,
) = merge_and_sort_multimodal_metadata(
@ -257,26 +246,26 @@ class Processor:
decoder_inputs["mm_hashes"] if self.use_hash else None,
)
-            # NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple
-            # modalities involved.
-            if len(sorted_modalities) > 1:
-                modality_order_dict = {
-                    modality: order
-                    for order, modality in enumerate(sorted_modalities)
-                }
-                # Sanity check to make sure each multimodal input has only one
-                # modality key.
-                for mm_input in individual_mm_inputs:
-                    assert len(mm_input.modalities) == 1
-                # Sort MultiModalKwargs to match sorted_mm_positions
-                sorted_mm_inputs = sorted(
-                    individual_mm_inputs,
-                    key=lambda mm_input: modality_order_dict[list(
-                        mm_input.modalities)[0]])
+            # The output of merged multi-modal processor (`decoder_mm_inputs`)
+            # is a single MultiModalKwargs for all items from all modalities.
+            # This code flattens kwargs for individual items in a list and
+            # sorts them by each item's position in the input sequence if there
+            # are multiple modalities.
+            unique_modalities = set(sorted_item_modalities)
+            if len(unique_modalities) > 1:
+                sorted_mm_inputs = []
+                used_indices = {modality: 0 for modality in unique_modalities}
+                for modality in sorted_item_modalities:
+                    items = decoder_mm_inputs.get_items(modality)
+                    item = items[used_indices[modality]]
+                    sorted_mm_inputs.append(
+                        MultiModalKwargs.from_items([item]))
+                    used_indices[modality] += 1
            else:
-                sorted_mm_inputs = individual_mm_inputs
+                sorted_mm_inputs = [
+                    MultiModalKwargs.from_items([item]) for item in
+                    decoder_mm_inputs.get_items(sorted_item_modalities[0])
+                ]
return EngineCoreRequest(
request_id=request_id,
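
The used_indices bookkeeping above is the heart of the change: each modality keeps its own cursor, so the i-th occurrence of a modality in sorted_item_modalities picks up the i-th item of that modality. A toy illustration of just that ordering logic, with plain strings standing in for the MultiModalKwargs items (illustration only, not vLLM API):

    # Toy stand-ins for decoder_mm_inputs.get_items(modality); the real code
    # works on MultiModalKwargs items, but the ordering logic is the same.
    sorted_item_modalities = ["image", "audio", "image"]
    items_by_modality = {
        "image": ["image_item_0", "image_item_1"],
        "audio": ["audio_item_0"],
    }

    sorted_items = []
    used_indices = {modality: 0 for modality in set(sorted_item_modalities)}
    for modality in sorted_item_modalities:
        # Take the next unconsumed item of this modality so the flattened list
        # follows each item's position in the input sequence.
        sorted_items.append(items_by_modality[modality][used_indices[modality]])
        used_indices[modality] += 1

    assert sorted_items == ["image_item_0", "audio_item_0", "image_item_1"]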