[V1] Support interleaved modality items (#15605)

Signed-off-by: Roger Wang <ywang@roblox.com>
Roger Wang 2025-03-29 06:30:09 -07:00 committed by GitHub
parent 6fa7cd3dbc
commit c67abd614f
6 changed files with 205 additions and 115 deletions


@@ -431,6 +431,7 @@ steps:
     - pytest -v -s models/encoder_decoder/audio_language -m core_model
     - pytest -v -s models/encoder_decoder/language -m core_model
     - pytest -v -s models/encoder_decoder/vision_language -m core_model
+    - pytest -v -s models/decoder_only/vision_language/test_interleaved.py
 
 - label: Multi-Modal Models Test (Extended) 1 # 48m
   optional: true


@@ -747,30 +747,27 @@ class VllmRunner:
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
     ) -> list[TextPrompt]:
-        if images is not None:
-            assert len(prompts) == len(images)
 
-        if videos is not None:
-            assert len(prompts) == len(videos)
+        if any(x is not None and len(x) != len(prompts)
+               for x in [images, videos, audios]):
+            raise ValueError(
+                "All non-None multimodal inputs must have the same length as "
+                "prompts")
 
-        if audios is not None:
-            assert len(prompts) == len(audios)
+        inputs = []
+        for i, prompt in enumerate(prompts):
+            multi_modal_data = {}
+            if images is not None and (image := images[i]) is not None:
+                multi_modal_data["image"] = image
+            if videos is not None and (video := videos[i]) is not None:
+                multi_modal_data["video"] = video
+            if audios is not None and (audio := audios[i]) is not None:
+                multi_modal_data["audio"] = audio
 
-        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
-        if images is not None:
-            for i, image in enumerate(images):
-                if image is not None:
-                    inputs[i]["multi_modal_data"] = {"image": image}
-
-        if videos is not None:
-            for i, video in enumerate(videos):
-                if video is not None:
-                    inputs[i]["multi_modal_data"] = {"video": video}
-
-        if audios is not None:
-            for i, audio in enumerate(audios):
-                if audio is not None:
-                    inputs[i]["multi_modal_data"] = {"audio": audio}
+            inputs.append(
+                TextPrompt(prompt=prompt,
+                           multi_modal_data=multi_modal_data
+                           if multi_modal_data else None))
 
         return inputs
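For context (not part of the diff): a standalone sketch of the per-prompt pairing the reworked helper now performs, using plain dicts and made-up placeholder values instead of vLLM's TextPrompt and real image/video objects.

# Illustrative only: parallel lists with one entry per prompt (None = no item).
prompts = ["<image>\nDescribe the image.", "<video>\nDescribe the video."]
images = ["IMG_0", None]   # placeholder objects, not real PIL images
videos = [None, "VID_1"]   # placeholder objects, not real video arrays

inputs = []
for i, prompt in enumerate(prompts):
    multi_modal_data = {}
    if images is not None and images[i] is not None:
        multi_modal_data["image"] = images[i]
    if videos is not None and videos[i] is not None:
        multi_modal_data["video"] = videos[i]
    inputs.append({"prompt": prompt,
                   "multi_modal_data": multi_modal_data or None})

assert inputs[0]["multi_modal_data"] == {"image": "IMG_0"}
assert inputs[1]["multi_modal_data"] == {"video": "VID_1"}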


@@ -0,0 +1,77 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+from vllm.assets.image import ImageAsset
+from vllm.assets.video import VideoAsset
+
+models = ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]
+
+
+def base_prompt(modalities_str: str) -> str:
+    return f"<|im_start|>user {modalities_str}\nDescribe what you see from these items.<|im_end|><|im_start|>assistant\n"  # noqa: E501
+
+
+INTERLEAVED_PROMPT = base_prompt("<image><video><image>\n")
+NONINTERLEAVED_PROMPT = base_prompt("<image><image><video>\n")
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["float16"])
+@pytest.mark.parametrize("max_tokens", [128])
+def test_models(vllm_runner, model, dtype: str, max_tokens: int) -> None:
+    """
+    This is a simple test to check that interleaved and non-interleaved
+    prompts do not give the same result.
+    """
+    image_cherry = ImageAsset("cherry_blossom").pil_image.convert("RGB")
+    image_stop = ImageAsset("stop_sign").pil_image.convert("RGB")
+    images = [image_cherry, image_stop]
+    video = VideoAsset(name="sample_demo_1.mp4", num_frames=16).np_ndarrays
+
+    inputs = [
+        (
+            [INTERLEAVED_PROMPT],
+            [images],
+            [video],
+        ),
+        (
+            [NONINTERLEAVED_PROMPT],
+            [images],
+            [video],
+        ),
+    ]
+
+    with vllm_runner(model,
+                     task="generate",
+                     dtype=dtype,
+                     limit_mm_per_prompt={"image": 2},
+                     max_model_len=32768,
+                     max_num_seqs=2,
+                     tensor_parallel_size=1,
+                     enforce_eager=True) as vllm_model:
+        vllm_outputs_per_case = [
+            vllm_model.generate_greedy(prompts,
+                                       max_tokens,
+                                       images=images,
+                                       videos=videos)
+            for prompts, images, videos in inputs
+        ]
+
+    all_results = [output[0][1] for output in vllm_outputs_per_case]
+
+    outputs = [(total_str, total_str.find("assistant\n") + len("assistant\n"))
+               for total_str in all_results]
+
+    prompt_lengths = [prompt_len for _, prompt_len in outputs]
+    generated_strs = [
+        total_str[prompt_len:] for total_str, prompt_len in outputs
+    ]
+    interleaved_prompt_len, noninterleaved_prompt_len = prompt_lengths
+    interleaved_output_str, noninterleaved_output_str = generated_strs
+
+    # The two prompts are identical except for the order of modality tokens.
+    assert interleaved_prompt_len == noninterleaved_prompt_len
+
+    # The two generated strings should be different because of the
+    # interleaved modality tokens.
+    assert interleaved_output_str != noninterleaved_output_str


@@ -155,7 +155,7 @@ def test_merge_and_sort_multimodal_metadata():
                 ]
             },
             mm_hashes={"image": ["hash1", "hash2"]},
-            expected_modalities=["image"],
+            expected_modalities=["image", "image"],
             expected_ranges=[
                 PlaceholderRange(offset=0, length=2),
                 PlaceholderRange(offset=3, length=2),
@@ -172,7 +172,7 @@ def test_merge_and_sort_multimodal_metadata():
                 ]
             },
             mm_hashes=None,
-            expected_modalities=["image"],
+            expected_modalities=["image", "image"],
             expected_ranges=[
                 PlaceholderRange(offset=0, length=2),
                 PlaceholderRange(offset=2, length=2),
@@ -197,7 +197,7 @@ def test_merge_and_sort_multimodal_metadata():
                 "image": ["image_hash1", "image_hash2"],
                 "audio": ["audio_hash1", "audio_hash2"],
             },
-            expected_modalities=["audio", "image"],
+            expected_modalities=["audio", "audio", "image", "image"],
             expected_ranges=[
                 PlaceholderRange(offset=0, length=2),
                 PlaceholderRange(offset=2, length=3),
@@ -223,7 +223,7 @@ def test_merge_and_sort_multimodal_metadata():
                 ]
             },
             mm_hashes=None,
-            expected_modalities=["audio", "image"],
+            expected_modalities=["audio", "audio", "image", "image"],
             expected_ranges=[
                 PlaceholderRange(offset=0, length=2),
                 PlaceholderRange(offset=2, length=3),
@@ -254,7 +254,9 @@ def test_merge_and_sort_multimodal_metadata():
                 "audio": ["audio_hash1"],
                 "video": ["video_hash1", "video_hash2", "video_hash3"]
             },
-            expected_modalities=["audio", "video", "image"],
+            expected_modalities=[
+                "audio", "video", "video", "video", "image", "image"
+            ],
             expected_ranges=[
                 PlaceholderRange(offset=0, length=2),
                 PlaceholderRange(offset=3, length=4),
@@ -300,12 +302,19 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
                 "image": ["image_hash1", "image_hash2"],
                 "audio": ["audio_hash1", "audio_hash2"],
             },
-            expected_modalities=[],
-            expected_ranges=[],
-            expected_hashes=None,
+            expected_modalities=["image", "audio", "image", "audio"],
+            expected_ranges=[
+                PlaceholderRange(offset=0, length=4),
+                PlaceholderRange(offset=5, length=2),
+                PlaceholderRange(offset=8, length=2),
+                PlaceholderRange(offset=11, length=4),
+            ],
+            expected_hashes=[
+                "image_hash1", "audio_hash1", "image_hash2", "audio_hash2"
+            ],
         ),
-        # <image> <image> <video> <audio> <image>
+        # <image> <image> <audio> <video> <image>
         TestCase(
             mm_positions={
                 "image": [
@@ -321,15 +330,54 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
                 ]
             },
             mm_hashes=None,
-            expected_modalities=[],
-            expected_ranges=[],
+            expected_modalities=["image", "image", "audio", "video", "image"],
+            expected_ranges=[
+                PlaceholderRange(offset=0, length=2),
+                PlaceholderRange(offset=2, length=3),
+                PlaceholderRange(offset=5, length=2),
+                PlaceholderRange(offset=8, length=5),
+                PlaceholderRange(offset=20, length=4),
+            ],
             expected_hashes=None,
         ),
+        # <image> <audio> <video> <image> with hashes
+        TestCase(
+            mm_positions={
+                "image": [
+                    PlaceholderRange(offset=0, length=2),
+                    PlaceholderRange(offset=18, length=4),
+                ],
+                "audio": [
+                    PlaceholderRange(offset=6, length=2),
+                ],
+                "video": [
+                    PlaceholderRange(offset=10, length=5),
+                ]
+            },
+            mm_hashes={
+                "image": ["image_hash1", "image_hash2"],
+                "audio": ["audio_hash1"],
+                "video": ["video_hash1"],
+            },
+            expected_modalities=["image", "audio", "video", "image"],
+            expected_ranges=[
+                PlaceholderRange(offset=0, length=2),
+                PlaceholderRange(offset=6, length=2),
+                PlaceholderRange(offset=10, length=5),
+                PlaceholderRange(offset=18, length=4),
+            ],
+            expected_hashes=[
+                "image_hash1", "audio_hash1", "video_hash1", "image_hash2"
+            ],
+        ),
     ]
 
-    for case in test_cases:
-        with pytest.raises(ValueError) as ex_info:
-            merge_and_sort_multimodal_metadata(case.mm_positions,
-                                               case.mm_hashes)
-
-        assert "Interleaved mixed-modality" in str(ex_info.value)
+    for (mm_positions, mm_hashes, expected_modalities, expected_ranges,
+         expected_hashes) in test_cases:
+        modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
+            mm_positions, mm_hashes)
+
+        assert modalities == expected_modalities
+        assert ranges == expected_ranges
+        assert hashes == expected_hashes


@@ -304,13 +304,9 @@ def merge_and_sort_multimodal_metadata(
     Optionally if a MultiModalHashDict is given, same operation will be
     applied to the object and the sorted list of hashes will be returned.
 
-    Raises:
-        ValueError: If the input prompt has interleaved placeholders from
-        different modalities (e.g, "<image><audio><image> Describe the
-        content.")
-
     Returns:
-        list[str]: Sorted list of involved modalities.
+        list[str]: List of item modalities in order of their positions in
+        the input sequence.
         list[PlaceholderRange]: Sorted list of all PlaceholdeRanges from
         mm_positions.
         Optional[list[str]]: Sorted list of all hashes from mm_hashes if
@@ -324,50 +320,33 @@ def merge_and_sort_multimodal_metadata(
     # For single modality, placeholder ranges and hashes are already sorted
     # so we can return the list directly.
     if len(modalities) == 1:
-        if mm_hashes is None:
-            return modalities, list(mm_positions[modalities[0]]), None
-        else:
-            return modalities, list(mm_positions[modalities[0]]), list(
-                mm_hashes[modalities[0]])
-
-    placeholder_lists_with_modality = [(modality, mm_positions[modality])
-                                       for modality in modalities]
-
-    if mm_hashes is None:
-        sorted_placeholder_lists = sorted(placeholder_lists_with_modality,
-                                          key=lambda x: x[1][0]['offset'])
-        sorted_hash_lists = None
-    else:
-        hashes_lists = [
-            mm_hashes[modality] for modality in modalities
-            if modality in mm_hashes
-        ]
-        sorted_pairs = sorted(zip(placeholder_lists_with_modality,
-                                  hashes_lists),
-                              key=lambda x: x[0][1][0]['offset'])
-
-        sorted_placeholder_tuple, sorted_hash_tuple = zip(*sorted_pairs)
-        sorted_placeholder_lists = list(sorted_placeholder_tuple)
-        sorted_hash_lists = list(sorted_hash_tuple)
-
-    sorted_modalities = [modality for modality, _ in sorted_placeholder_lists]
-
-    # Flatten sorted list of lists to a single list and verify there is no
-    # interleaving of placeholders from different modalities.
-    merged_placeholders: list[PlaceholderRange] = []
-    for modality, placeholder_list in sorted_placeholder_lists:
-        if merged_placeholders and placeholder_list[0][
-                'offset'] < merged_placeholders[-1]['offset']:
-            raise ValueError(
-                "Interleaved mixed-modality inference is currently not "
-                "supported.")
-        merged_placeholders.extend(placeholder_list)
-
-    if sorted_hash_lists is not None:
-        merged_hashes = []
-        for hash_list in sorted_hash_lists:
-            merged_hashes.extend(hash_list)
-    else:
-        merged_hashes = None
+        modality = modalities[0]
+        placeholder_list = list(mm_positions[modality])
+
+        return [modality] * len(
+            placeholder_list
+        ), placeholder_list, None if not mm_hashes else mm_hashes[modality]
+
+    # Create a list of (modality, placeholder, hash) tuples for all placeholders
+    all_items = []
+    for modality in modalities:
+        placeholder_list = list(mm_positions[modality])
+        hash_list: list[Optional[str]] = list(
+            mm_hashes[modality]) if mm_hashes and modality in mm_hashes else [
+                None
+            ] * len(placeholder_list)
+
+        for placeholder, hash_value in zip(placeholder_list, hash_list):
+            all_items.append((modality, placeholder, hash_value))
+
+    # Sort all items by offset
+    all_items.sort(key=lambda x: x[1]['offset'])
+
+    # Split into separate lists
+    sorted_modalities = [item[0] for item in all_items]
+    merged_placeholders = [item[1] for item in all_items]
+    merged_hashes = [str(item[2])
+                     for item in all_items] if mm_hashes is not None else None
 
     return sorted_modalities, merged_placeholders, merged_hashes
@@ -383,8 +362,7 @@ def group_mm_inputs_by_modality(
     Returns:
         list[list[MultiModalKwargs]]: List of list of MultiModalKwargs, each
-        inner list contains consecutive MultiModalKwargs with same modality, or
-        one with multimodal modalities.
+        inner list contains consecutive MultiModalKwargs with same modality.
     """
 
     if not mm_inputs:
         return []
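To make the new contract concrete, here is a worked example mirroring the "<image> <audio> <video> <image> with hashes" test case added above. It is a sketch rather than a runnable snippet: the import paths for merge_and_sort_multimodal_metadata and PlaceholderRange are not shown in this diff.

# mm_positions/mm_hashes taken verbatim from the new test case.
mm_positions = {
    "image": [PlaceholderRange(offset=0, length=2),
              PlaceholderRange(offset=18, length=4)],
    "audio": [PlaceholderRange(offset=6, length=2)],
    "video": [PlaceholderRange(offset=10, length=5)],
}
mm_hashes = {
    "image": ["image_hash1", "image_hash2"],
    "audio": ["audio_hash1"],
    "video": ["video_hash1"],
}

modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
    mm_positions, mm_hashes)

# One entry per item (not per modality), ordered by placeholder offset:
# modalities == ["image", "audio", "video", "image"]
# ranges     == placeholders at offsets 0, 6, 10, 18
# hashes     == ["image_hash1", "audio_hash1", "video_hash1", "image_hash2"]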


@@ -234,22 +234,11 @@ class Processor:
         if decoder_inputs["type"] == "multimodal":
             decoder_mm_inputs = decoder_inputs["mm_kwargs"]
 
-            # The output of merged multi-modal processor (`decoder_mm_inputs`)
-            # contains the kwargs for all items from all modalities.
-            # This code separates them so that there is one set of kwargs
-            # per item per modality.
-            individual_mm_inputs = [
-                MultiModalKwargs.from_items([item])
-                for modality in decoder_mm_inputs.modalities
-                for item in decoder_mm_inputs.get_items(modality)
-            ]
-
             # Merge and flatten multimodal placeholders, hashes and inputs
             # from dictionaries to lists, and sort them by each item's position
             # in the input sequence.
-            # NOTE: interleaved modalities are not supported.
             (
-                sorted_modalities,
+                sorted_item_modalities,
                 sorted_mm_positions,
                 sorted_mm_hashes,
             ) = merge_and_sort_multimodal_metadata(
@@ -257,26 +246,26 @@ class Processor:
                 decoder_inputs["mm_hashes"] if self.use_hash else None,
             )
 
-            # NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple
-            # modalities involved.
-            if len(sorted_modalities) > 1:
-                modality_order_dict = {
-                    modality: order
-                    for order, modality in enumerate(sorted_modalities)
-                }
-
-                # Sanity check to make sure each multimodal input has only one
-                # modality key.
-                for mm_input in individual_mm_inputs:
-                    assert len(mm_input.modalities) == 1
-
-                # Sort MultiModalKwargs to match sorted_mm_positions
-                sorted_mm_inputs = sorted(
-                    individual_mm_inputs,
-                    key=lambda mm_input: modality_order_dict[list(
-                        mm_input.modalities)[0]])
+            # The output of merged multi-modal processor (`decoder_mm_inputs`)
+            # is a single MultiModalKwargs for all items from all modalities.
+            # This code flattens kwargs for individual items in a list and
+            # sorts them by each item's position in the input sequence if there
+            # are multiple modalities.
+            unique_modalities = set(sorted_item_modalities)
+            if len(unique_modalities) > 1:
+                sorted_mm_inputs = []
+                used_indices = {modality: 0 for modality in unique_modalities}
+                for modality in sorted_item_modalities:
+                    items = decoder_mm_inputs.get_items(modality)
+                    item = items[used_indices[modality]]
+                    sorted_mm_inputs.append(MultiModalKwargs.from_items([item
+                                                                         ]))
+                    used_indices[modality] += 1
             else:
-                sorted_mm_inputs = individual_mm_inputs
+                sorted_mm_inputs = [
+                    MultiModalKwargs.from_items([item]) for item in
+                    decoder_mm_inputs.get_items(sorted_item_modalities[0])
+                ]
 
             return EngineCoreRequest(
                 request_id=request_id,
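For intuition (not vLLM code): a minimal standalone sketch of the per-modality cursor pattern the processor now uses to reorder per-item kwargs, with plain strings standing in for the MultiModalKwargs items.

# Items as the merged processor groups them (per modality, in original order),
# and the interleaved per-item modality order from the sorted metadata,
# e.g. for a prompt like <image><audio><image>.
items_by_modality = {"image": ["img_item_0", "img_item_1"],
                     "audio": ["audio_item_0"]}
sorted_item_modalities = ["image", "audio", "image"]

# Keep one cursor per modality so each item is consumed exactly once and ends
# up aligned with its placeholder position in the prompt.
used_indices = {modality: 0 for modality in set(sorted_item_modalities)}
sorted_items = []
for modality in sorted_item_modalities:
    sorted_items.append(items_by_modality[modality][used_indices[modality]])
    used_indices[modality] += 1

assert sorted_items == ["img_item_0", "audio_item_0", "img_item_1"]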