[V1] Support interleaved modality items (#15605)
Signed-off-by: Roger Wang <ywang@roblox.com>
parent 6fa7cd3dbc
commit c67abd614f
@@ -431,6 +431,7 @@ steps:
     - pytest -v -s models/encoder_decoder/audio_language -m core_model
     - pytest -v -s models/encoder_decoder/language -m core_model
     - pytest -v -s models/encoder_decoder/vision_language -m core_model
+    - pytest -v -s models/decoder_only/vision_language/test_interleaved.py

 - label: Multi-Modal Models Test (Extended) 1 # 48m
   optional: true
@@ -747,30 +747,27 @@ class VllmRunner:
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
     ) -> list[TextPrompt]:
-        if images is not None:
-            assert len(prompts) == len(images)
-
-        if videos is not None:
-            assert len(prompts) == len(videos)
+        if any(x is not None and len(x) != len(prompts)
+               for x in [images, videos, audios]):
+            raise ValueError(
+                "All non-None multimodal inputs must have the same length as "
+                "prompts")

-        if audios is not None:
-            assert len(prompts) == len(audios)
+        inputs = []
+        for i, prompt in enumerate(prompts):
+            multi_modal_data = {}
+            if images is not None and (image := images[i]) is not None:
+                multi_modal_data["image"] = image
+            if videos is not None and (video := videos[i]) is not None:
+                multi_modal_data["video"] = video
+            if audios is not None and (audio := audios[i]) is not None:
+                multi_modal_data["audio"] = audio

-        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
-        if images is not None:
-            for i, image in enumerate(images):
-                if image is not None:
-                    inputs[i]["multi_modal_data"] = {"image": image}
-
-        if videos is not None:
-            for i, video in enumerate(videos):
-                if video is not None:
-                    inputs[i]["multi_modal_data"] = {"video": video}
-
-        if audios is not None:
-            for i, audio in enumerate(audios):
-                if audio is not None:
-                    inputs[i]["multi_modal_data"] = {"audio": audio}
+            inputs.append(
+                TextPrompt(prompt=prompt,
+                           multi_modal_data=multi_modal_data
+                           if multi_modal_data else None))

         return inputs
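As context for the hunk above: the refactored helper now builds one `multi_modal_data` dict per prompt, keeping only the modalities that are non-`None` at that prompt's index, and rejects length mismatches with a `ValueError` instead of per-modality asserts. A standalone sketch of that pattern (the helper name and the plain-dict return type here are illustrative, not part of the diff):

```python
from typing import Any, Optional, Sequence


def pair_prompts_with_mm_data(
    prompts: Sequence[str],
    images: Optional[Sequence[Any]] = None,
    videos: Optional[Sequence[Any]] = None,
    audios: Optional[Sequence[Any]] = None,
) -> list[dict[str, Any]]:
    # Reject mismatched lengths up front, mirroring the new ValueError check.
    if any(x is not None and len(x) != len(prompts)
           for x in (images, videos, audios)):
        raise ValueError("All non-None multimodal inputs must have the same "
                         "length as prompts")

    inputs = []
    for i, prompt in enumerate(prompts):
        # Each prompt keeps only the modality items present at its index.
        multi_modal_data: dict[str, Any] = {}
        if images is not None and images[i] is not None:
            multi_modal_data["image"] = images[i]
        if videos is not None and videos[i] is not None:
            multi_modal_data["video"] = videos[i]
        if audios is not None and audios[i] is not None:
            multi_modal_data["audio"] = audios[i]
        inputs.append({
            "prompt": prompt,
            "multi_modal_data": multi_modal_data or None,
        })
    return inputs
```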
@@ -0,0 +1,77 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+from vllm.assets.image import ImageAsset
+from vllm.assets.video import VideoAsset
+
+models = ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]
+
+
+def base_prompt(modalities_str: str) -> str:
+    return f"<|im_start|>user {modalities_str}\nDescribe what you see from these items.<|im_end|><|im_start|>assistant\n"  # noqa: E501
+
+
+INTERLEAVED_PROMPT = base_prompt("<image><video><image>\n")
+NONINTERLEAVED_PROMPT = base_prompt("<image><image><video>\n")
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["float16"])
+@pytest.mark.parametrize("max_tokens", [128])
+def test_models(vllm_runner, model, dtype: str, max_tokens: int) -> None:
"""
|
||||
This is a simple test to check if interleaved and non-interleaved prompts
|
||||
give the same result.
|
||||
"""
|
||||
|
||||
+    image_cherry = ImageAsset("cherry_blossom").pil_image.convert("RGB")
+    image_stop = ImageAsset("stop_sign").pil_image.convert("RGB")
+    images = [image_cherry, image_stop]
+    video = VideoAsset(name="sample_demo_1.mp4", num_frames=16).np_ndarrays
+
+    inputs = [
+        (
+            [INTERLEAVED_PROMPT],
+            [images],
+            [video],
+        ),
+        (
+            [NONINTERLEAVED_PROMPT],
+            [images],
+            [video],
+        ),
+    ]
+
+    with vllm_runner(model,
+                     task="generate",
+                     dtype=dtype,
+                     limit_mm_per_prompt={"image": 2},
+                     max_model_len=32768,
+                     max_num_seqs=2,
+                     tensor_parallel_size=1,
+                     enforce_eager=True) as vllm_model:
+        vllm_outputs_per_case = [
+            vllm_model.generate_greedy(prompts,
+                                       max_tokens,
+                                       images=images,
+                                       videos=videos)
+            for prompts, images, videos in inputs
+        ]
+
+    all_results = [output[0][1] for output in vllm_outputs_per_case]
+    outputs = [(total_str, total_str.find("assistant\n") + len("assistant\n"))
+               for total_str in all_results]
+    prompt_lengths = [prompt_len for _, prompt_len in outputs]
+    generated_strs = [
+        total_str[prompt_len:] for total_str, prompt_len in outputs
+    ]
+    interleaved_prompt_len, noninterleaved_prompt_len = prompt_lengths
+    interleaved_output_str, noninterleaved_output_str = generated_strs
+
+    # The two prompts are identical except for the order of modality tokens.
+    assert interleaved_prompt_len == noninterleaved_prompt_len
+
+    # The two generated strings should be different because of the
+    # interleaved modality tokens.
+    assert interleaved_output_str != noninterleaved_output_str
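The new test drives this through `vllm_runner`, but the same interleaved placement can be exercised with vLLM's offline API directly. A minimal sketch mirroring the test above (model name, prompt template and assets are taken from the test; passing a list of images together with `limit_mm_per_prompt` follows vLLM's multi-modal input format, and placeholder order in the prompt decides which item lands where):

```python
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset

llm = LLM(
    model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
    limit_mm_per_prompt={"image": 2},
    max_model_len=32768,
    enforce_eager=True,
)

images = [
    ImageAsset("cherry_blossom").pil_image.convert("RGB"),
    ImageAsset("stop_sign").pil_image.convert("RGB"),
]
video = VideoAsset(name="sample_demo_1.mp4", num_frames=16).np_ndarrays

# <image><video><image>: consumes images[0], then the video, then images[1].
prompt = ("<|im_start|>user <image><video><image>\n"
          "Describe what you see from these items.<|im_end|>"
          "<|im_start|>assistant\n")

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": images, "video": video},
    },
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)
```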
@@ -155,7 +155,7 @@ def test_merge_and_sort_multimodal_metadata():
                 ]
             },
             mm_hashes={"image": ["hash1", "hash2"]},
-            expected_modalities=["image"],
+            expected_modalities=["image", "image"],
             expected_ranges=[
                 PlaceholderRange(offset=0, length=2),
                 PlaceholderRange(offset=3, length=2),
@@ -172,7 +172,7 @@ def test_merge_and_sort_multimodal_metadata():
                 ]
             },
             mm_hashes=None,
-            expected_modalities=["image"],
+            expected_modalities=["image", "image"],
             expected_ranges=[
                 PlaceholderRange(offset=0, length=2),
                 PlaceholderRange(offset=2, length=2),
@@ -197,7 +197,7 @@ def test_merge_and_sort_multimodal_metadata():
                 "image": ["image_hash1", "image_hash2"],
                 "audio": ["audio_hash1", "audio_hash2"],
             },
-            expected_modalities=["audio", "image"],
+            expected_modalities=["audio", "audio", "image", "image"],
             expected_ranges=[
                 PlaceholderRange(offset=0, length=2),
                 PlaceholderRange(offset=2, length=3),
@@ -223,7 +223,7 @@ def test_merge_and_sort_multimodal_metadata():
                 ]
             },
             mm_hashes=None,
-            expected_modalities=["audio", "image"],
+            expected_modalities=["audio", "audio", "image", "image"],
             expected_ranges=[
                 PlaceholderRange(offset=0, length=2),
                 PlaceholderRange(offset=2, length=3),
@@ -254,7 +254,9 @@ def test_merge_and_sort_multimodal_metadata():
                 "audio": ["audio_hash1"],
                 "video": ["video_hash1", "video_hash2", "video_hash3"]
             },
-            expected_modalities=["audio", "video", "image"],
+            expected_modalities=[
+                "audio", "video", "video", "video", "image", "image"
+            ],
             expected_ranges=[
                 PlaceholderRange(offset=0, length=2),
                 PlaceholderRange(offset=3, length=4),
@@ -300,12 +302,19 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
                 "image": ["image_hash1", "image_hash2"],
                 "audio": ["audio_hash1", "audio_hash2"],
             },
-            expected_modalities=[],
-            expected_ranges=[],
-            expected_hashes=None,
+            expected_modalities=["image", "audio", "image", "audio"],
+            expected_ranges=[
+                PlaceholderRange(offset=0, length=4),
+                PlaceholderRange(offset=5, length=2),
+                PlaceholderRange(offset=8, length=2),
+                PlaceholderRange(offset=11, length=4),
+            ],
+            expected_hashes=[
+                "image_hash1", "audio_hash1", "image_hash2", "audio_hash2"
+            ],
         ),

-        # <image> <image> <video> <audio> <image>
+        # <image> <image> <audio> <video> <image>
         TestCase(
             mm_positions={
                 "image": [
@@ -321,15 +330,54 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
                 ]
             },
             mm_hashes=None,
-            expected_modalities=[],
-            expected_ranges=[],
+            expected_modalities=["image", "image", "audio", "video", "image"],
+            expected_ranges=[
+                PlaceholderRange(offset=0, length=2),
+                PlaceholderRange(offset=2, length=3),
+                PlaceholderRange(offset=5, length=2),
+                PlaceholderRange(offset=8, length=5),
+                PlaceholderRange(offset=20, length=4),
+            ],
             expected_hashes=None,
         ),

+        # <image> <audio> <video> <image> with hashes
+        TestCase(
+            mm_positions={
+                "image": [
+                    PlaceholderRange(offset=0, length=2),
+                    PlaceholderRange(offset=18, length=4),
+                ],
+                "audio": [
+                    PlaceholderRange(offset=6, length=2),
+                ],
+                "video": [
+                    PlaceholderRange(offset=10, length=5),
+                ]
+            },
+            mm_hashes={
+                "image": ["image_hash1", "image_hash2"],
+                "audio": ["audio_hash1"],
+                "video": ["video_hash1"],
+            },
+            expected_modalities=["image", "audio", "video", "image"],
+            expected_ranges=[
+                PlaceholderRange(offset=0, length=2),
+                PlaceholderRange(offset=6, length=2),
+                PlaceholderRange(offset=10, length=5),
+                PlaceholderRange(offset=18, length=4),
+            ],
+            expected_hashes=[
+                "image_hash1", "audio_hash1", "video_hash1", "image_hash2"
+            ],
+        ),
     ]

-    for case in test_cases:
-        with pytest.raises(ValueError) as ex_info:
-            merge_and_sort_multimodal_metadata(case.mm_positions,
-                                               case.mm_hashes)
+    for (mm_positions, mm_hashes, expected_modalities, expected_ranges,
+         expected_hashes) in test_cases:
+        modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
+            mm_positions, mm_hashes)

-        assert "Interleaved mixed-modality" in str(ex_info.value)
+        assert modalities == expected_modalities
+        assert ranges == expected_ranges
+        assert hashes == expected_hashes
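To make the updated expectations concrete, the last test case above can be run end to end: each multimodal item now contributes its own entry, ordered by offset, instead of the call raising `ValueError`. The import paths below are assumptions about where these helpers live (the rendered diff does not show file names):

```python
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata

mm_positions = {
    "image": [PlaceholderRange(offset=0, length=2),
              PlaceholderRange(offset=18, length=4)],
    "audio": [PlaceholderRange(offset=6, length=2)],
    "video": [PlaceholderRange(offset=10, length=5)],
}
mm_hashes = {
    "image": ["image_hash1", "image_hash2"],
    "audio": ["audio_hash1"],
    "video": ["video_hash1"],
}

modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
    mm_positions, mm_hashes)

# One entry per item, in offset order across modalities.
assert modalities == ["image", "audio", "video", "image"]
assert [r["offset"] for r in ranges] == [0, 6, 10, 18]
assert hashes == ["image_hash1", "audio_hash1", "video_hash1", "image_hash2"]
```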
@@ -303,14 +303,10 @@ def merge_and_sort_multimodal_metadata(

     Optionally if a MultiModalHashDict is given, same operation will be
     applied to the object and the sorted list of hashes will be returned.

-    Raises:
-        ValueError: If the input prompt has interleaved placeholders from
-        different modalities (e.g, "<image><audio><image> Describe the
-        content.")
-
     Returns:
-        list[str]: Sorted list of involved modalities.
+        list[str]: List of item modalities in order of their positions in
+        the input sequence.
         list[PlaceholderRange]: Sorted list of all PlaceholdeRanges from
         mm_positions.
         Optional[list[str]]: Sorted list of all hashes from mm_hashes if
@@ -324,50 +320,33 @@ def merge_and_sort_multimodal_metadata(
     # For single modality, placeholder ranges and hashes are already sorted
     # so we can return the list directly.
     if len(modalities) == 1:
-        if mm_hashes is None:
-            return modalities, list(mm_positions[modalities[0]]), None
-        else:
-            return modalities, list(mm_positions[modalities[0]]), list(
-                mm_hashes[modalities[0]])
+        modality = modalities[0]
+        placeholder_list = list(mm_positions[modality])

-    placeholder_lists_with_modality = [(modality, mm_positions[modality])
-                                       for modality in modalities]
+        return [modality] * len(
+            placeholder_list
+        ), placeholder_list, None if not mm_hashes else mm_hashes[modality]

-    if mm_hashes is None:
-        sorted_placeholder_lists = sorted(placeholder_lists_with_modality,
-                                          key=lambda x: x[1][0]['offset'])
-        sorted_hash_lists = None
-    else:
-        hashes_lists = [
-            mm_hashes[modality] for modality in modalities
-            if modality in mm_hashes
-        ]
-        sorted_pairs = sorted(zip(placeholder_lists_with_modality,
-                                  hashes_lists),
-                              key=lambda x: x[0][1][0]['offset'])
-        sorted_placeholder_tuple, sorted_hash_tuple = zip(*sorted_pairs)
-        sorted_placeholder_lists = list(sorted_placeholder_tuple)
-        sorted_hash_lists = list(sorted_hash_tuple)
+    # Create a list of (modality, placeholder, hash) tuples for all placeholders
+    all_items = []
+    for modality in modalities:
+        placeholder_list = list(mm_positions[modality])
+        hash_list: list[Optional[str]] = list(
+            mm_hashes[modality]) if mm_hashes and modality in mm_hashes else [
+                None
+            ] * len(placeholder_list)

-    sorted_modalities = [modality for modality, _ in sorted_placeholder_lists]
+        for placeholder, hash_value in zip(placeholder_list, hash_list):
+            all_items.append((modality, placeholder, hash_value))

-    # Flatten sorted list of lists to a single list and verify there is no
-    # interleaving of placeholders from different modalities.
-    merged_placeholders: list[PlaceholderRange] = []
-    for modality, placeholder_list in sorted_placeholder_lists:
-        if merged_placeholders and placeholder_list[0][
-                'offset'] < merged_placeholders[-1]['offset']:
-            raise ValueError(
-                "Interleaved mixed-modality inference is currently not "
-                "supported.")
-        merged_placeholders.extend(placeholder_list)
+    # Sort all items by offset
+    all_items.sort(key=lambda x: x[1]['offset'])

-    if sorted_hash_lists is not None:
-        merged_hashes = []
-        for hash_list in sorted_hash_lists:
-            merged_hashes.extend(hash_list)
-    else:
-        merged_hashes = None
+    # Split into separate lists
+    sorted_modalities = [item[0] for item in all_items]
+    merged_placeholders = [item[1] for item in all_items]
+    merged_hashes = [str(item[2])
+                     for item in all_items] if mm_hashes is not None else None

     return sorted_modalities, merged_placeholders, merged_hashes
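A quick check of the single-modality fast path above, using the data from the first updated test case: the per-modality list is returned as-is (it is assumed to already be in offset order) and the modality name is repeated once per item. Import paths are assumed, as in the earlier sketch:

```python
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata

modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
    {"image": [PlaceholderRange(offset=0, length=2),
               PlaceholderRange(offset=3, length=2)]},
    {"image": ["hash1", "hash2"]},
)

assert modalities == ["image", "image"]
assert [r["offset"] for r in ranges] == [0, 3]
assert hashes == ["hash1", "hash2"]
```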
@@ -383,8 +362,7 @@ def group_mm_inputs_by_modality(

     Returns:
         list[list[MultiModalKwargs]]: List of list of MultiModalKwargs, each
-        inner list contains consecutive MultiModalKwargs with same modality, or
-        one with multimodal modalities.
+        inner list contains consecutive MultiModalKwargs with same modality.
     """
     if not mm_inputs:
         return []
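The behavior described by the revised docstring (consecutive inputs of the same modality end up in the same inner list) can be pictured with `itertools.groupby`. A toy sketch over (modality, payload) pairs rather than real `MultiModalKwargs` objects; it is not the vLLM implementation:

```python
from itertools import groupby


def group_consecutive_by_modality(items: list[tuple[str, object]]):
    """Group consecutive (modality, payload) pairs sharing a modality."""
    return [[payload for _, payload in group]
            for _, group in groupby(items, key=lambda item: item[0])]


grouped = group_consecutive_by_modality([
    ("image", "img0"), ("image", "img1"), ("video", "vid0"), ("image", "img2"),
])
# The trailing image is NOT merged into the first group because it is not
# consecutive with it.
assert grouped == [["img0", "img1"], ["vid0"], ["img2"]]
```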
@@ -234,22 +234,11 @@ class Processor:
         if decoder_inputs["type"] == "multimodal":
             decoder_mm_inputs = decoder_inputs["mm_kwargs"]

-            # The output of merged multi-modal processor (`decoder_mm_inputs`)
-            # contains the kwargs for all items from all modalities.
-            # This code separates them so that there is one set of kwargs
-            # per item per modality.
-            individual_mm_inputs = [
-                MultiModalKwargs.from_items([item])
-                for modality in decoder_mm_inputs.modalities
-                for item in decoder_mm_inputs.get_items(modality)
-            ]
-
             # Merge and flatten multimodal placeholders, hashes and inputs
             # from dictionaries to lists, and sort them by each item's position
             # in the input sequence.
-            # NOTE: interleaved modalities are not supported.
             (
-                sorted_modalities,
+                sorted_item_modalities,
                 sorted_mm_positions,
                 sorted_mm_hashes,
             ) = merge_and_sort_multimodal_metadata(
@@ -257,26 +246,26 @@
                 decoder_inputs["mm_hashes"] if self.use_hash else None,
             )

-            # NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple
-            # modalities involved.
-            if len(sorted_modalities) > 1:
-                modality_order_dict = {
-                    modality: order
-                    for order, modality in enumerate(sorted_modalities)
-                }
-
-                # Sanity check to make sure each multimodal input has only one
-                # modality key.
-                for mm_input in individual_mm_inputs:
-                    assert len(mm_input.modalities) == 1
-
-                # Sort MultiModalKwargs to match sorted_mm_positions
-                sorted_mm_inputs = sorted(
-                    individual_mm_inputs,
-                    key=lambda mm_input: modality_order_dict[list(
-                        mm_input.modalities)[0]])
+            # The output of merged multi-modal processor (`decoder_mm_inputs`)
+            # is a single MultiModalKwargs for all items from all modalities.
+            # This code flattens kwargs for individual items in a list and
+            # sorts them by each item's position in the input sequence if there
+            # are multiple modalities.
+            unique_modalities = set(sorted_item_modalities)
+            if len(unique_modalities) > 1:
+                sorted_mm_inputs = []
+                used_indices = {modality: 0 for modality in unique_modalities}
+                for modality in sorted_item_modalities:
+                    items = decoder_mm_inputs.get_items(modality)
+                    item = items[used_indices[modality]]
+                    sorted_mm_inputs.append(MultiModalKwargs.from_items([item
+                                                                         ]))
+                    used_indices[modality] += 1
             else:
-                sorted_mm_inputs = individual_mm_inputs
+                sorted_mm_inputs = [
+                    MultiModalKwargs.from_items([item]) for item in
+                    decoder_mm_inputs.get_items(sorted_item_modalities[0])
+                ]

         return EngineCoreRequest(
             request_id=request_id,
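The `used_indices` bookkeeping above walks the prompt-ordered modality list and takes the next unused item of each modality; within a modality, items are assumed to already be in prompt order. A small standalone sketch (function and variable names are illustrative; `items_by_modality` stands in for `decoder_mm_inputs.get_items(...)`):

```python
def reorder_items(sorted_item_modalities: list[str],
                  items_by_modality: dict[str, list[object]]) -> list[object]:
    # Track how many items of each modality have been consumed so far.
    used_indices = {modality: 0 for modality in items_by_modality}
    ordered = []
    for modality in sorted_item_modalities:
        ordered.append(items_by_modality[modality][used_indices[modality]])
        used_indices[modality] += 1
    return ordered


assert reorder_items(
    ["image", "audio", "image"],
    {"image": ["img0", "img1"], "audio": ["aud0"]},
) == ["img0", "aud0", "img1"]
```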