[Core][VLM] Add precise multi-modal placeholder tracking (#8346)

Signed-off-by: Peter Salas <peter@fixie.ai>
Author: Peter Salas
Date: 2024-11-01 16:21:10 -07:00 (committed via GitHub)
Commit: 6c0b7f548d
Parent: d151fde834
53 changed files with 913 additions and 281 deletions


@@ -34,11 +34,7 @@ def run_ultravox(question: str, audio_count: int):
                                            tokenize=False,
                                            add_generation_prompt=True)
-    llm = LLM(model=model_name,
-              enforce_eager=True,
-              enable_chunked_prefill=False,
-              max_model_len=8192,
-              limit_mm_per_prompt={"audio": audio_count})
+    llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count})
     stop_token_ids = None
     return llm, prompt, stop_token_ids


@@ -869,6 +869,7 @@ def make_test_metadata(
         return attn_backend.make_metadata(
             num_prefills=num_prefills,
             slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
+            multi_modal_placeholder_index_maps=None,
             num_prefill_tokens=num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             seq_lens=seq_lens,
@@ -914,6 +915,7 @@ def make_test_metadata(
         return attn_backend.make_metadata(
             num_prefills=num_prefills,
             slot_mapping=kv_mmap.slot_mapping,
+            multi_modal_placeholder_index_maps=None,
             num_prefill_tokens=num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             seq_lens=seq_lens,


@@ -2,8 +2,10 @@ from typing import List, Optional, Tuple, Type
 import numpy as np
 import pytest
+import pytest_asyncio
 from transformers import AutoModel, AutoTokenizer, BatchEncoding
 
+from tests.utils import RemoteOpenAIServer
 from vllm.sequence import SampleLogprobs
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
@@ -17,6 +19,13 @@ AudioTuple = Tuple[np.ndarray, int]
 VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
 HF_PLACEHOLDER = "<|audio|>"
+
+CHUNKED_PREFILL_KWARGS = {
+    "enable_chunked_prefill": True,
+    "max_num_seqs": 2,
+    # Use a very small limit to exercise chunked prefill.
+    "max_num_batched_tokens": 16
+}
 
 @pytest.fixture(scope="session")
 def audio_assets():
@@ -30,6 +39,26 @@ def audio(request):
     return AudioAsset(request.param)
 
 
+@pytest.fixture(params=({}, CHUNKED_PREFILL_KWARGS))
+def server(request, audio_assets):
+    args = [
+        "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
+        f"--limit-mm-per-prompt=audio={len(audio_assets)}"
+    ] + [
+        f"--{key.replace('_','-')}={value}"
+        for key, value in request.param.items()
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client
+
+
 def _get_prompt(audio_count, question, placeholder):
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     placeholder = f"{placeholder}\n" * audio_count
@@ -68,8 +97,7 @@ def run_test(
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
-    tensor_parallel_size: int,
-    distributed_executor_backend: Optional[str] = None,
+    **kwargs,
 ):
     """Inference result should be the same between hf and vllm."""
     torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
@@ -79,11 +107,8 @@ def run_test(
     # if we run HF first, the cuda initialization will be done and it
     # will hurt multiprocessing backend with fork method (the default method).
-    with vllm_runner(model,
-                     dtype=dtype,
-                     tensor_parallel_size=tensor_parallel_size,
-                     distributed_executor_backend=distributed_executor_backend,
-                     enforce_eager=True) as vllm_model:
+    with vllm_runner(model, dtype=dtype, enforce_eager=True,
+                     **kwargs) as vllm_model:
         vllm_outputs_per_audio = [
             vllm_model.generate_greedy_logprobs([vllm_prompt],
                                                 max_tokens,
@@ -135,18 +160,16 @@ def run_multi_audio_test(
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
-    tensor_parallel_size: int,
-    distributed_executor_backend: Optional[str] = None,
+    **kwargs,
 ):
     with vllm_runner(model,
                      dtype=dtype,
-                     tensor_parallel_size=tensor_parallel_size,
-                     distributed_executor_backend=distributed_executor_backend,
                      enforce_eager=True,
                      limit_mm_per_prompt={
                          "audio":
                          max((len(audio) for _, audio in prompts_and_audios))
-                     }) as vllm_model:
+                     },
+                     **kwargs) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy_logprobs(
             [prompt for prompt, _ in prompts_and_audios],
             max_tokens,
@@ -162,8 +185,9 @@ def run_multi_audio_test(
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS])
 def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
-                num_logprobs: int) -> None:
+                num_logprobs: int, vllm_kwargs: dict) -> None:
     vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER)
     hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER)
@@ -175,7 +199,7 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
         dtype=dtype,
         max_tokens=max_tokens,
         num_logprobs=num_logprobs,
-        tensor_parallel_size=1,
+        **vllm_kwargs,
     )
@@ -183,9 +207,10 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS])
 def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
-                                     max_tokens: int,
-                                     num_logprobs: int) -> None:
+                                     max_tokens: int, num_logprobs: int,
+                                     vllm_kwargs: dict) -> None:
     vllm_prompt = _get_prompt(len(audio_assets),
                               "Describe each of the audios above.",
@@ -198,5 +223,37 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
         dtype=dtype,
         max_tokens=max_tokens,
         num_logprobs=num_logprobs,
-        tensor_parallel_size=1,
+        **vllm_kwargs,
     )
+
+
+@pytest.mark.asyncio
+async def test_online_inference(client, audio_assets):
+    """Exercises online inference with/without chunked prefill enabled."""
+
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            *[{
+                "type": "audio_url",
+                "audio_url": {
+                    "url": audio.url
+                }
+            } for audio in audio_assets],
+            {
+                "type":
+                "text",
+                "text":
+                f"What's happening in these {len(audio_assets)} audio clips?"
+            },
+        ],
+    }]
+
+    chat_completion = await client.chat.completions.create(model=MODEL_NAME,
+                                                           messages=messages,
+                                                           max_tokens=10)
+
+    assert len(chat_completion.choices) == 1
+    choice = chat_completion.choices[0]
+    assert choice.finish_reason == "length"

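The new `test_online_inference` above exercises this request shape end to end through the OpenAI-compatible server. For orientation, a standalone sketch of an equivalent request; the server address, model name, and audio URL below are illustrative placeholders rather than values taken from this PR:

# Standalone sketch of the request test_online_inference issues above,
# assuming a vLLM OpenAI-compatible server is already running locally.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
chat = client.chat.completions.create(
    model="fixie-ai/ultravox-v0_3",  # illustrative model name
    messages=[{
        "role": "user",
        "content": [
            {"type": "audio_url",
             "audio_url": {"url": "https://example.com/clip.wav"}},
            {"type": "text", "text": "What's happening in this audio clip?"},
        ],
    }],
    max_tokens=10,
)
print(chat.choices[0].message.content)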

@@ -5,8 +5,8 @@ from unittest.mock import patch
 import pytest
 import torch
-from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs
-from vllm.inputs.registry import InputRegistry
+from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext,
+                         InputRegistry, token_inputs)
 from vllm.multimodal import MultiModalRegistry
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
@@ -56,7 +56,7 @@ def use_dummy_data_mock():
                               num_crops=DEFAULT_NUM_CROPS):
         seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops))
-        return seq_data, None
+        return DummyData(seq_data, None)
 
     with patch(
             "vllm.inputs.registry.InputRegistry._default_dummy_data_factory",
@@ -177,9 +177,9 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
     # NOTE: seq_len is thrown away here since this will leverage the
     # default dummy data factory that we have patched in, whose seq
     # len is solely dependent on the value of the mm_processor_kwargs.
-    seq_data, _ = dummy_registry.dummy_data_for_profiling(
+    dummy_data = dummy_registry.dummy_data_for_profiling(
         ctx.model_config, seq_len=-1, mm_registry=mm_registry)
-    assert len(seq_data.prompt_token_ids) == expected_seq_count
+    assert len(dummy_data.seq_data.prompt_token_ids) == expected_seq_count
 
 
 @pytest.mark.parametrize(
@@ -206,9 +206,9 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
     # NOTE: seq_len is thrown away here since this will leverage the
     # default dummy data factory that we have patched in, whose seq
     # len is solely dependent on the value of the mm_processor_kwargs.
-    seq_data, _ = dummy_registry.dummy_data_for_profiling(
+    dummy_data = dummy_registry.dummy_data_for_profiling(
         ctx.model_config, seq_len=-1, mm_registry=mm_registry)
-    assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
+    assert len(dummy_data.seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
 
 
 ### Test overrides for the max token count per multimodal instance


@@ -92,18 +92,50 @@ def test_repeat_and_pad_placeholder_tokens(model):
     tokenizer = AutoTokenizer.from_pretrained(model)
 
     test_cases = [
-        ("<image>", 2, "<image><image>", [32000, 32000]),
-        ("<image><image>", 2, "<image><image><image>", [32000, 32000, 32000]),
-        ("<image><image>", [3, 2], "<image><image><image><image><image>",
-         [32000, 32000, 32000, 32000, 32000]),
-        ("Image:<image>Image:<image>!", [3, 2],
-         "Image:<image><image><image>Image:<image><image>!",
-         [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]),
-        ("<image>", [3, 2], "<image><image><image>", [32000, 32000, 32000]),
-    ]
+        (
+            "<image>",
+            2,
+            "<image><image>",
+            [32000, 32000],
+            [{ "offset": 0, "length": 2 }],
+        ),
+        (
+            "<image><image>",
+            2,
+            "<image><image><image>",
+            [32000, 32000, 32000],
+            [{ "offset": 0, "length": 2 }]),
+        (
+            "<image><image>",
+            [3, 2],
+            "<image><image><image><image><image>",
+            [32000, 32000, 32000, 32000, 32000],
+            [{ "offset": 0, "length": 3 }, { "offset": 3, "length": 2 }],
+        ),
+        (
+            "Image:<image>Image:<image>!",
+            [3, 2],
+            "Image:<image><image><image>Image:<image><image>!",
+            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
+            [{ "offset": 2, "length": 3 }, { "offset": 7, "length": 2 }],
+        ),
+        (
+            "<image>",
+            [3, 2],
+            "<image><image><image>",
+            [32000, 32000, 32000],
+            [{ "offset": 0, "length": 3 }],
+        ),
+    ]  # yapf: disable
 
-    for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases:
-        new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
+    for (
+            prompt,
+            repeat_count,
+            expected_prompt,
+            expected_token_ids,
+            expected_ranges,
+    ) in test_cases:
+        new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
             tokenizer=tokenizer,
             prompt=prompt,
             prompt_token_ids=tokenizer.encode(prompt,
@@ -113,3 +145,4 @@ def test_repeat_and_pad_placeholder_tokens(model):
         )
         assert new_prompt == expected_prompt
         assert new_token_ids == expected_token_ids
+        assert ranges == expected_ranges

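The new fifth element in each test case is a list of placeholder ranges: each is an offset and length into the expanded token list, and together they cover exactly the repeated placeholder tokens. A quick sketch of that reading, using values copied from the fourth case above:

# Each returned range covers only <image> placeholder ids (32000) in the
# expanded prompt_token_ids; values are taken from the fourth test case.
expanded_ids = [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]
ranges = [{"offset": 2, "length": 3}, {"offset": 7, "length": 2}]

for r in ranges:
    span = expanded_ids[r["offset"]:r["offset"] + r["length"]]
    assert span == [32000] * r["length"]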

@@ -73,6 +73,7 @@ def test_model_runner_input():
         num_prefill_tokens=2,
         num_decode_tokens=3,
         slot_mapping=torch.zeros(1),
+        multi_modal_placeholder_index_maps=None,
     )
     model_input = ModelInputForGPUWithSamplingMetadata(
         input_tokens=torch.ones(10),
@@ -124,6 +125,7 @@ def test_embedding_model_runner_input():
         num_prefill_tokens=2,
         num_decode_tokens=3,
         slot_mapping=torch.zeros(1),
+        multi_modal_placeholder_index_maps=None,
     )
     model_input = ModelInputForGPUWithPoolingMetadata(
         input_tokens=torch.ones(10),
@@ -174,6 +176,7 @@ def test_multi_step_model_runner_input():
         num_prefill_tokens=2,
         num_decode_tokens=3,
         slot_mapping=torch.zeros(1),
+        multi_modal_placeholder_index_maps=None,
     )
     frozen_model_input = ModelInputForGPUWithSamplingMetadata(
         input_tokens=torch.ones(10),


@@ -7,6 +7,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
 import torch
 
+from vllm.multimodal import MultiModalPlaceholderMap
+
 if TYPE_CHECKING:
     from vllm.worker.model_runner_base import (ModelRunnerBase,
                                                ModelRunnerInputBase,
@@ -108,6 +110,15 @@ class AttentionMetadata:
     # in block 0, and 1st slot in block 1, respectively.
     slot_mapping: torch.Tensor
 
+    # The index maps that relate multi-modal embeddings to the corresponding
+    # placeholders.
+    #
+    # N.B. These aren't really related to attention and don't belong on this
+    # type -- this is just a temporary solution to make them available to
+    # `model_executable`.
+    multi_modal_placeholder_index_maps: Optional[Dict[
+        str, MultiModalPlaceholderMap.IndexMap]]
+
     @property
     @abstractmethod
     def prefill_metadata(self) -> Optional["AttentionMetadata"]:

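For orientation, a hypothetical sketch of how a model runner might consume these per-modality index maps when scattering multi-modal embeddings into the text embeddings. The `src`/`dest` attribute names on `IndexMap` are an assumption made for illustration; only the `Optional[Dict[str, MultiModalPlaceholderMap.IndexMap]]` shape is established by this change.

# Hypothetical sketch only: assumes IndexMap carries paired src/dest indices
# (src -> rows of the flattened multi-modal embeddings, dest -> flattened
# placeholder token positions). These attribute names are not confirmed here.
import torch

def scatter_mm_embeddings(inputs_embeds: torch.Tensor,
                          mm_embeds: torch.Tensor,
                          index_map) -> torch.Tensor:
    src = torch.as_tensor(index_map.src, device=mm_embeds.device)
    dest = torch.as_tensor(index_map.dest, device=inputs_embeds.device)
    inputs_embeds[dest] = mm_embeds[src]
    return inputs_embeds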

@@ -215,6 +215,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=0,
             slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
+            multi_modal_placeholder_index_maps=self.
+            multi_modal_placeholder_index_maps,
             seq_lens=self.seq_lens[:self.num_prefills],
             seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
             max_query_len=self.max_query_len,
@@ -243,6 +245,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
             num_prefill_tokens=0,
             num_decode_tokens=self.num_decode_tokens,
             slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
+            multi_modal_placeholder_index_maps=None,
             seq_lens=None,
             seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
             max_query_len=None,


@@ -1,4 +1,5 @@
 """Attention layer with FlashAttention."""
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
@@ -14,6 +15,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
 from vllm.forward_context import get_forward_context
+from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
                         make_tensor_with_pad)
@@ -169,6 +171,8 @@ class FlashAttentionMetadata(AttentionMetadata):
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=0,
             slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
+            multi_modal_placeholder_index_maps=self.
+            multi_modal_placeholder_index_maps,
             seq_lens=self.seq_lens[:self.num_prefills],
             seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
             max_query_len=self.max_query_len,
@@ -198,6 +202,7 @@ class FlashAttentionMetadata(AttentionMetadata):
             num_prefill_tokens=0,
             num_decode_tokens=self.num_decode_tokens,
             slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
+            multi_modal_placeholder_index_maps=None,
             seq_lens=None,
             seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
             max_decode_query_len=self.max_decode_query_len,
@@ -297,6 +302,9 @@ class FlashAttentionMetadataBuilder(
         self.context_lens: List[int] = []
         self.block_tables: List[List[int]] = []
         self.curr_seq_lens: List[int] = []
+        self.multimodal_placeholder_maps: Dict[
+            str,
+            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
         self.num_prefills = 0
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
@@ -327,6 +335,12 @@ class FlashAttentionMetadataBuilder(
         self.context_lens.append(context_len)
 
         if is_prompt:
+            mm_maps = inter_data.multi_modal_placeholder_maps
+            if mm_maps:
+                for modality, placeholders in mm_maps.items():
+                    self.multimodal_placeholder_maps[modality].extend(
+                        placeholders)
+
             self.num_prefills += 1
             self.num_prefill_tokens += token_len
             self.prefill_seq_lens.append(seq_len)
@@ -449,6 +463,11 @@ class FlashAttentionMetadataBuilder(
         seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                     dtype=torch.int32,
                                     device=device)
+        placeholder_index_maps = {
+            modality: placeholder_map.index_map()
+            for modality, placeholder_map in
+            self.multimodal_placeholder_maps.items()
+        }
         torch.cumsum(seq_lens_tensor,
                      dim=0,
                      dtype=seq_start_loc.dtype,
@@ -464,6 +483,7 @@ class FlashAttentionMetadataBuilder(
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             seq_lens=seq_lens,
+            multi_modal_placeholder_index_maps=placeholder_index_maps,
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
             max_decode_query_len=max_decode_query_len,


@@ -1,7 +1,10 @@
+from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
 
+from vllm.multimodal import MultiModalPlaceholderMap
+
 try:
     from flashinfer import BatchDecodeWithPagedKVCacheWrapper
     from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
@@ -215,6 +218,7 @@ class FlashInferState(AttentionState):
         attn_metadata = self.runner.attn_backend.make_metadata(
             num_prefills=0,
             slot_mapping=self._graph_slot_mapping[:batch_size],
+            multi_modal_placeholder_index_maps=None,
             num_prefill_tokens=0,
             num_decode_tokens=batch_size,
             max_prefill_seq_len=0,
@@ -470,6 +474,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.context_lens: List[int] = []
         self.block_tables: List[List[int]] = []
         self.curr_seq_lens: List[int] = []
+        self.multimodal_placeholder_maps: Dict[
+            str,
+            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
         self.num_prefills = 0
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
@@ -519,6 +526,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 inter_data.curr_sliding_window_blocks):
             self.context_lens.append(context_len)
         if is_prompt:
+            mm_maps = inter_data.multi_modal_placeholder_maps
+            if mm_maps:
+                for modality, placeholders in mm_maps.items():
+                    self.multimodal_placeholder_maps[modality].extend(
+                        placeholders)
+
             self.num_prefills += 1
             self.num_prefill_tokens += token_len
             self.prefill_seq_lens.append(seq_len)
@@ -651,6 +663,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                     dtype=torch.int32,
                                     device=device)
+        placeholder_index_maps = {
+            modality: placeholder_map.index_map()
+            for modality, placeholder_map in
+            self.multimodal_placeholder_maps.items()
+        }
         torch.cumsum(seq_lens_tensor,
                      dim=0,
                      dtype=seq_start_loc.dtype,
@@ -694,6 +711,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             decode_query_len=decode_query_len,
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
+            multi_modal_placeholder_index_maps=placeholder_index_maps,
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             max_prefill_seq_len=max_prefill_seq_len,


@@ -1,5 +1,6 @@
+from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
 
 import torch
@@ -7,6 +8,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.multimodal import MultiModalPlaceholderMap
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUBuilder
@@ -135,6 +137,8 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
+            multi_modal_placeholder_index_maps=self.
+            multi_modal_placeholder_index_maps,
             seq_lens=self.seq_lens[:self.num_prefills],
             seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
             max_decode_query_len=0,
@@ -167,6 +171,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
             num_prefill_tokens=0,
             num_decode_tokens=self.num_decode_tokens,
             slot_mapping=slot_mapping,
+            multi_modal_placeholder_index_maps=None,
             seq_lens=None,
             seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
             max_decode_query_len=self.max_decode_query_len,
@@ -189,6 +194,9 @@ class PlaceholderAttentionMetadataBuilder(
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
         self.curr_seq_lens: List[int] = []
+        self.multimodal_placeholder_maps: Dict[
+            str,
+            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
         self.num_prefills = 0
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
@@ -213,6 +221,12 @@ class PlaceholderAttentionMetadataBuilder(
         self.context_lens.append(context_len)
 
         if is_prompt:
+            mm_maps = inter_data.multi_modal_placeholder_maps
+            if mm_maps:
+                for modality, placeholders in mm_maps.items():
+                    self.multimodal_placeholder_maps[modality].extend(
+                        placeholders)
+
             self.num_prefills += 1
             self.num_prefill_tokens += token_len
             self.prefill_seq_lens.append(seq_len)
@@ -280,6 +294,11 @@ class PlaceholderAttentionMetadataBuilder(
         seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                     dtype=torch.int32,
                                     device=device)
+        placeholder_index_maps = {
+            modality: placeholder_map.index_map()
+            for modality, placeholder_map in
+            self.multimodal_placeholder_maps.items()
+        }
         torch.cumsum(seq_lens_tensor,
                      dim=0,
                      dtype=seq_start_loc.dtype,
@@ -296,6 +315,7 @@ class PlaceholderAttentionMetadataBuilder(
         return PlaceholderAttentionMetadata(
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping,
+            multi_modal_placeholder_index_maps=placeholder_index_maps,
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             seq_lens=seq_lens,


@@ -150,6 +150,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=0,
             slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
+            multi_modal_placeholder_index_maps=self.
+            multi_modal_placeholder_index_maps,
             seq_lens=self.seq_lens[:self.num_prefills],
             seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
             max_query_len=self.max_query_len,
@@ -178,6 +180,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
             num_prefill_tokens=0,
             num_decode_tokens=self.num_decode_tokens,
             slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
+            multi_modal_placeholder_index_maps=None,
             seq_lens=None,
             seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
             max_query_len=None,


@@ -1,4 +1,5 @@
 """Attention backend utils"""
+from collections import defaultdict
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union
@@ -7,6 +8,7 @@ import torch
 from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
                             AttentionState)
+from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 
 if TYPE_CHECKING:
@@ -123,6 +125,9 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
         self.context_lens: List[int] = []
         self.block_tables: List[List[int]] = []
         self.curr_seq_lens: List[int] = []
+        self.multimodal_placeholder_maps: Dict[
+            str,
+            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
         self.num_prefills = 0
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
@@ -147,6 +152,12 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
                 inter_data.curr_sliding_window_blocks):
             self.context_lens.append(context_len)
         if is_prompt:
+            mm_maps = inter_data.multi_modal_placeholder_maps
+            if mm_maps:
+                for modality, placeholders in mm_maps.items():
+                    self.multimodal_placeholder_maps[modality].extend(
+                        placeholders)
+
             self.num_prefills += 1
             self.num_prefill_tokens += token_len
             self.prefill_seq_lens.append(seq_len)
@@ -242,6 +253,11 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
         seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                     dtype=torch.int32,
                                     device=device)
+        placeholder_index_maps = {
+            modality: placeholder_map.index_map()
+            for modality, placeholder_map in
+            self.multimodal_placeholder_maps.items()
+        }
         torch.cumsum(seq_lens_tensor,
                      dim=0,
                      dtype=seq_start_loc.dtype,
@@ -254,6 +270,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
         return self._metadata_cls(  # type: ignore
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
+            multi_modal_placeholder_index_maps=placeholder_index_maps,
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             seq_lens=seq_lens,
@@ -305,6 +322,7 @@ class CommonAttentionState(AttentionState):
             num_prefill_tokens=0,
             num_decode_tokens=batch_size,
             slot_mapping=self._graph_slot_mapping[:batch_size],
+            multi_modal_placeholder_index_maps=None,
             seq_lens=None,
             seq_lens_tensor=self._graph_seq_lens[:batch_size],
             max_query_len=1,


@@ -212,6 +212,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
+            multi_modal_placeholder_index_maps=self.
+            multi_modal_placeholder_index_maps,
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=self.max_query_len,
@@ -255,6 +257,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
             num_prefill_tokens=0,
             num_decode_tokens=self.num_decode_tokens,
             slot_mapping=slot_mapping,
+            multi_modal_placeholder_index_maps=None,
             seq_lens_tensor=seq_lens_tensor,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.max_decode_seq_len,


@@ -1308,6 +1308,8 @@ class Scheduler:
             # `multi_modal_data` will be None.
             multi_modal_data=seq_group.multi_modal_data
             if scheduler_outputs.num_prefill_groups > 0 else None,
+            multi_modal_placeholders=seq_group.multi_modal_placeholders
+            if scheduler_outputs.num_prefill_groups > 0 else None,
             mm_processor_kwargs=seq_group.mm_processor_kwargs,
             prompt_adapter_request=seq_group.prompt_adapter_request,
         )


@@ -3,7 +3,7 @@ from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
                    SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
                    build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
                    token_inputs, zip_enc_dec_prompts)
-from .registry import InputContext, InputRegistry
+from .registry import DummyData, InputContext, InputRegistry
 
 INPUT_REGISTRY = InputRegistry()
 """
@@ -29,6 +29,7 @@ __all__ = [
     "to_enc_dec_tuple_list",
     "zip_enc_dec_prompts",
     "INPUT_REGISTRY",
+    "DummyData",
    "InputContext",
    "InputRegistry",
 ]


@@ -4,7 +4,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
 from typing_extensions import NotRequired, TypedDict, TypeVar
 
 if TYPE_CHECKING:
-    from vllm.multimodal import MultiModalDataDict
+    from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
 
 
 class TextPrompt(TypedDict):
@@ -136,6 +136,12 @@ class TokenInputs(TypedDict):
     if the model supports it.
     """
 
+    multi_modal_placeholders: NotRequired[
+        Optional["MultiModalPlaceholderDict"]]
+    """
+    Placeholder ranges for the multi-modal data.
+    """
+
     mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]]
     """
     Optional multi-modal processor kwargs to be forwarded to the
@@ -149,6 +155,7 @@ def token_inputs(
     prompt_token_ids: List[int],
     prompt: Optional[str] = None,
     multi_modal_data: Optional["MultiModalDataDict"] = None,
+    multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
 ) -> TokenInputs:
     """Construct :class:`TokenInputs` from optional values."""
@@ -158,6 +165,8 @@ def token_inputs(
         inputs["prompt"] = prompt
     if multi_modal_data is not None:
         inputs["multi_modal_data"] = multi_modal_data
+    if multi_modal_placeholders is not None:
+        inputs["multi_modal_placeholders"] = multi_modal_placeholders
     if mm_processor_kwargs is not None:
         inputs["mm_processor_kwargs"] = mm_processor_kwargs

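Taken together, an input processor can now return placeholder ranges alongside the expanded token ids. A minimal sketch of the extended `token_inputs()` call shape; the token ids and range values below are illustrative rather than produced by a real tokenizer:

# Sketch: token_inputs() with per-modality placeholder ranges attached.
from PIL import Image
from vllm.inputs import token_inputs

dummy_image = Image.new("RGB", (336, 336))  # stand-in image for illustration
inputs = token_inputs(
    prompt_token_ids=[1, 9833, 28747, 32000, 32000, 32000, 2],
    prompt="Image:<image>",
    multi_modal_data={"image": dummy_image},
    multi_modal_placeholders={"image": [{"offset": 3, "length": 3}]},
)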

@@ -1,8 +1,8 @@
 import functools
 from collections import UserDict
 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional,
-                    Protocol, Tuple, Type)
+from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple,
+                    Optional, Protocol, Type)
 
 from torch import nn
 from transformers import PretrainedConfig
@@ -16,7 +16,8 @@ from .data import DecoderOnlyInputs
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig
-    from vllm.multimodal import MultiModalDataDict, MultiModalRegistry
+    from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict,
+                                 MultiModalRegistry)
     from vllm.sequence import SequenceData
 
 logger = init_logger(__name__)
@@ -63,6 +64,14 @@ class InputContext:
 N = TypeVar("N", bound=Type[nn.Module])
 
 
+class DummyData(NamedTuple):
+    """Dummy data used for profiling."""
+
+    seq_data: "SequenceData"
+    multi_modal_data: Optional["MultiModalDataDict"] = None
+    multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None
+
+
 class DummyDataFactory(Protocol):
 
     def __call__(
@@ -71,7 +80,7 @@ class DummyDataFactory(Protocol):
         seq_len: int,
         mm_counts: Mapping[str, int],
         **mm_processor_kwargs: Any,
-    ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
+    ) -> DummyData:
        """
        Create dummy data to be inputted into the model.
@@ -123,7 +132,7 @@ class InputRegistry:
         ctx: InputContext,
         seq_len: int,
         mm_counts: Mapping[str, int],
-    ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
+    ) -> DummyData:
        """
        The default dummy data factory represents the longest possible text
        that can be inputted to the model.
@@ -134,10 +143,7 @@ class InputRegistry:
         # Avoid circular import
         from vllm.sequence import SequenceData
 
-        dummy_seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
-        dummy_multi_modal_data = None
-
-        return dummy_seq_data, dummy_multi_modal_data
+        return DummyData(SequenceData.from_prompt_token_counts((0, seq_len)))
 
     def register_dummy_data(self, factory: DummyDataFactory):
         """
@@ -195,7 +201,7 @@ class InputRegistry:
         seq_len: int,
         mm_registry: "MultiModalRegistry",
         is_encoder_data: bool = False,
-    ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
+    ) -> DummyData:
        """
        Create dummy data for profiling the memory usage of a model.
@@ -220,12 +226,12 @@ class InputRegistry:
         mm_processor_kwargs = get_allowed_kwarg_only_overrides(
             dummy_factory, overrides=model_config.mm_processor_kwargs)
 
-        seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len,
-                                          _MultiModalCounts(mm_counts),
-                                          **mm_processor_kwargs)
+        dummy_data = dummy_factory(InputContext(model_config), seq_len,
+                                   _MultiModalCounts(mm_counts),
+                                   **mm_processor_kwargs)
 
         # Having more tokens is over-conservative but otherwise fine
-        num_tokens = seq_data.prompt_token_ids
+        num_tokens = dummy_data.seq_data.prompt_token_ids
         if len(num_tokens) < seq_len:
             if is_encoder_data:
                 print_warning_once(
@@ -235,15 +241,15 @@ class InputRegistry:
                 raise AssertionError(
                     f"Expected at least {seq_len} dummy tokens for profiling, "
                     f"but found {len(num_tokens)} tokens instead.")
-        if mm_data is not None:
-            for k, v in mm_data.items():
+        if dummy_data.multi_modal_data is not None:
+            for k, v in dummy_data.multi_modal_data.items():
                 num_items = len(v) if isinstance(v, list) else 1
                 num_expected = mm_counts[k]
                 assert num_items >= num_expected, (
                     f"Expected at least {num_expected} dummy '{k}' instances "
                     f"for profiling, but found {num_items} instances instead.")
 
-        return seq_data, mm_data
+        return dummy_data
 
     def _default_input_processor(
         self,

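Under the new contract, a model-specific dummy-data factory returns a `DummyData` named tuple instead of a `(seq_data, mm_data)` pair. A sketch for a hypothetical image model, using the same helpers the converted models use below; the image token id and feature size are illustrative only, and a real factory would also supply dummy multi-modal data:

# Sketch of a dummy-data factory under the new DummyData contract.
from vllm.inputs import DummyData, InputContext
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import SequenceData

def dummy_data_for_toy_model(ctx: InputContext, seq_len: int, mm_counts):
    num_images = mm_counts["image"]
    image_token_id = 32000      # illustrative
    image_feature_size = 576    # illustrative
    seq_data = SequenceData.from_prompt_token_counts(
        (image_token_id, image_feature_size * num_images),
        (0, seq_len - image_feature_size * num_images),
    )
    ranges = consecutive_placeholder_ranges(num_items=num_images,
                                            item_size=image_feature_size)
    # Dummy images omitted here for brevity; pass them as the second field.
    return DummyData(seq_data, None, {"image": ranges})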

@@ -98,6 +98,11 @@ def input_processor_for_blip(
     if multi_modal_data is None or "image" not in multi_modal_data:
         return inputs
 
+    if "multi_modal_placeholders" in inputs and "image" in inputs[
+            "multi_modal_placeholders"]:
+        # The inputs already have placeholders.
+        return inputs
+
     tokenizer = cached_get_tokenizer(model_config.tokenizer)
 
     if image_feature_size_override is None:
@@ -105,7 +110,7 @@ def input_processor_for_blip(
     else:
         image_feature_size = image_feature_size_override
 
-    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
+    new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
         tokenizer,
         inputs.get("prompt"),
         inputs["prompt_token_ids"],
@@ -116,7 +121,8 @@ def input_processor_for_blip(
     # NOTE: Create a defensive copy of the original inputs
     return token_inputs(prompt_token_ids=new_token_ids,
                         prompt=new_prompt,
-                        multi_modal_data=multi_modal_data)
+                        multi_modal_data=multi_modal_data,
+                        multi_modal_placeholders={"image": ranges})
 
 
 # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164  # noqa


@@ -9,13 +9,14 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
-                         token_inputs)
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
+                         InputContext, token_inputs)
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.utils import consecutive_placeholder_ranges
 from vllm.sequence import IntermediateTensors, SequenceData
 
 from .blip import (BlipVisionModel, dummy_image_for_blip,
@@ -425,7 +426,11 @@ def dummy_seq_data_for_blip2(
     return SequenceData.from_prompt_token_counts(
         (image_token_id, image_feature_size * num_images),
         (0, seq_len - image_feature_size * num_images),
-    )
+    ), {
+        "image":
+        consecutive_placeholder_ranges(num_items=num_images,
+                                       item_size=image_feature_size)
+    }
 
 
 def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
@@ -434,7 +439,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
     vision_config = hf_config.vision_config
     num_images = mm_counts["image"]
 
-    seq_data = dummy_seq_data_for_blip2(
+    seq_data, ranges = dummy_seq_data_for_blip2(
         hf_config,
         seq_len,
         num_images,
@@ -444,7 +449,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
     if isinstance(vision_config, Blip2VisionConfig):
         mm_data = dummy_image_for_blip(vision_config, num_images)
 
-        return seq_data, mm_data
+        return DummyData(seq_data, mm_data, ranges)
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)


@@ -11,8 +11,8 @@ from transformers import ChameleonConfig, ChameleonVQVAEConfig
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
-                         token_inputs)
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
+                         InputContext, token_inputs)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -30,6 +30,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.utils import (cached_get_tokenizer,
+                                   consecutive_placeholder_ranges,
                                    repeat_and_pad_placeholder_tokens)
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.utils import print_warning_once
@@ -73,7 +74,11 @@ def dummy_seq_data_for_chameleon(
     return SequenceData.from_prompt_token_counts(
         (image_token_id, image_feature_size * num_images),
         (0, seq_len - image_feature_size * num_images),
-    )
+    ), {
+        "image":
+        consecutive_placeholder_ranges(num_items=num_images,
+                                       item_size=image_feature_size)
+    }
 
 
 def dummy_image_for_chameleon(
@@ -97,14 +102,14 @@ def dummy_data_for_chameleon(ctx: InputContext, seq_len: int,
                              mm_counts: Mapping[str, int]):
     num_images = mm_counts["image"]
 
-    seq_data = dummy_seq_data_for_chameleon(
+    seq_data, ranges = dummy_seq_data_for_chameleon(
         seq_len,
         num_images,
         image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
     )
 
     mm_data = dummy_image_for_chameleon(num_images)
-    return seq_data, mm_data
+    return DummyData(seq_data, mm_data, ranges)
 
 
 def input_processor_for_chameleon(ctx: InputContext,
@@ -120,9 +125,14 @@ def input_processor_for_chameleon(ctx: InputContext,
     if multi_modal_data is None or "image" not in multi_modal_data:
         return inputs
 
+    if "multi_modal_placeholders" in inputs and "image" in inputs[
+            "multi_modal_placeholders"]:
+        # The inputs already have placeholders.
+        return inputs
+
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(model_config.tokenizer)
-    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
+    new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
         tokenizer,
         inputs.get("prompt"),
         inputs["prompt_token_ids"],


@@ -19,6 +19,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal.utils import (cached_get_tokenizer,
+                                   consecutive_placeholder_ranges,
                                    repeat_and_pad_placeholder_tokens)
 from vllm.sequence import SequenceData
@@ -49,14 +50,13 @@ def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
     return get_clip_image_feature_size(hf_config)
 
 
-def dummy_seq_data_for_clip(
-    hf_config: CLIPVisionConfig,
-    seq_len: int,
-    num_images: int,
-    *,
-    image_token_id: int,
-    image_feature_size_override: Optional[int] = None,
-):
+def dummy_seq_data_for_clip(hf_config: CLIPVisionConfig,
+                            seq_len: int,
+                            num_images: int,
+                            *,
+                            image_token_id: int,
+                            image_feature_size_override: Optional[int] = None,
+                            mm_key: str = "image"):
     if image_feature_size_override is None:
         image_feature_size = get_clip_image_feature_size(hf_config)
     else:
@@ -65,7 +65,11 @@ def dummy_seq_data_for_clip(
     return SequenceData.from_prompt_token_counts(
         (image_token_id, image_feature_size * num_images),
         (0, seq_len - image_feature_size * num_images),
-    )
+    ), {
+        mm_key:
+        consecutive_placeholder_ranges(num_items=num_images,
+                                       item_size=image_feature_size)
+    }
 
 
 def dummy_image_for_clip(
@@ -117,6 +121,11 @@ def input_processor_for_clip(
     if multi_modal_data is None or "image" not in multi_modal_data:
         return inputs
 
+    if "multi_modal_placeholders" in inputs and "image" in inputs[
+            "multi_modal_placeholders"]:
+        # The inputs already have placeholders.
+        return inputs
+
     tokenizer = cached_get_tokenizer(model_config.tokenizer)
 
     if image_feature_size_override is None:
@@ -130,7 +139,7 @@ def input_processor_for_clip(
     else:
         image_feature_size = image_feature_size_override
 
-    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
+    new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
         tokenizer,
         inputs.get("prompt"),
         inputs["prompt_token_ids"],
@@ -141,7 +150,8 @@ def input_processor_for_clip(
     # NOTE: Create a defensive copy of the original inputs
     return token_inputs(prompt_token_ids=new_token_ids,
                         prompt=new_prompt,
-                        multi_modal_data=multi_modal_data)
+                        multi_modal_data=multi_modal_data,
+                        multi_modal_placeholders={"image": ranges})
 
 
 # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164  # noqa

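The ranges attached by these dummy-data helpers are laid out back to back from offset 0, one per image, matching how the placeholder tokens come first in the dummy sequence. A pure-Python sketch of that layout for intuition only; this is an assumed equivalent, not the library implementation of `consecutive_placeholder_ranges`:

# Assumed-equivalent sketch of the consecutive layout used for dummy data:
# one range per item, each item_size long, packed starting at offset 0.
def consecutive_ranges(num_items: int, item_size: int):
    return [{"offset": i * item_size, "length": item_size}
            for i in range(num_items)]

assert consecutive_ranges(2, 3) == [
    {"offset": 0, "length": 3},
    {"offset": 3, "length": 3},
]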

@@ -27,8 +27,8 @@ from transformers import FuyuConfig, FuyuImageProcessor
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
-                         token_inputs)
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
+                         InputContext, token_inputs)
 from vllm.model_executor.layers.linear import ColumnParallelLinear
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -37,9 +37,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.image import cached_get_image_processor
-from vllm.multimodal.utils import cached_get_tokenizer
+from vllm.multimodal.utils import (cached_get_tokenizer,
+                                   consecutive_placeholder_ranges)
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)
+from vllm.utils import is_list_of
 
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
@@ -103,7 +105,11 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
     token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images
     token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
                        [0]) * (seq_len - image_feature_size * num_images)
-    return SequenceData(token_ids)
+    return SequenceData(token_ids), {
+        "image":
+        consecutive_placeholder_ranges(num_items=num_images,
+                                       item_size=image_feature_size)
+    }
 
 
 def dummy_image_for_fuyu(
@@ -119,15 +125,15 @@ def dummy_image_for_fuyu(
 def dummy_data_for_fuyu(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
     num_images = mm_counts["image"]
-    seq_data = dummy_seq_data_for_fuyu(ctx, seq_len, num_images)
+    seq_data, ranges = dummy_seq_data_for_fuyu(ctx, seq_len, num_images)
     mm_data = dummy_image_for_fuyu(num_images,
                                    image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
                                    image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT)
-    return seq_data, mm_data
+    return DummyData(seq_data, mm_data, ranges)
 
 
 def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
-                           data: Image.Image):
+                           data: List[Image.Image]):
     image_encoding = image_processor.preprocess(data, return_tensors="pt")
     batch_images = torch.stack([img[0] for img in image_encoding["images"]
                                 ]).unsqueeze(1)
@@ -158,8 +164,10 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
     model_config = ctx.model_config
     image_data = multi_modal_data["image"]
     new_multi_modal_data = {}
+    image_list = image_data if isinstance(image_data, list) else [image_data]
 
     # process image data
-    if isinstance(image_data, Image.Image):
+    if is_list_of(image_list, Image.Image):
         # Fuyu's image_processor can also finish token padding
         image_processor: FuyuImageProcessor = cached_get_image_processor(
             model_config.model)
@@ -171,7 +179,7 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
         ])
 
         new_multi_modal_data["image"] = image_patches
-    elif isinstance(image_data, torch.Tensor):
+    elif is_list_of(image_list, torch.Tensor):
         raise NotImplementedError("Embeddings input is not supported yet")
else: else:
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
@ -198,12 +206,13 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
def input_mapper_for_fuyu(ctx: InputContext, data: object): def input_mapper_for_fuyu(ctx: InputContext, data: object):
model_config = ctx.model_config model_config = ctx.model_config
if isinstance(data, Image.Image): data_list = data if isinstance(data, list) else [data]
if is_list_of(data_list, Image.Image):
# Fuyu's image_processor can also finish token padding # Fuyu's image_processor can also finish token padding
image_processor: FuyuImageProcessor = cached_get_image_processor( image_processor: FuyuImageProcessor = cached_get_image_processor(
model_config.model) model_config.model)
model_image_input = _fuyu_image_preprocess(image_processor, data) model_image_input = _fuyu_image_preprocess(image_processor, data_list)
data = torch.stack([ data = torch.stack([
image_patch[0] image_patch[0]
for image_patch in model_image_input["image_patches"] for image_patch in model_image_input["image_patches"]

View File

@ -17,8 +17,8 @@ from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.quantization import (AWQConfig, from vllm.model_executor.layers.quantization import (AWQConfig,
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
@ -379,7 +379,7 @@ class InternVLInputPipeline:
model_config.tokenizer, model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code) trust_remote_code=model_config.trust_remote_code)
seq_data = dummy_seq_data_for_clip( seq_data, ranges = dummy_seq_data_for_clip(
hf_config.vision_config, hf_config.vision_config,
seq_len, seq_len,
num_images, num_images,
@ -398,7 +398,7 @@ class InternVLInputPipeline:
image_height_override=max_image_height, image_height_override=max_image_height,
) )
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT) input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)

View File

@ -10,7 +10,8 @@ from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
@ -111,7 +112,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
image_feature_size = get_max_llava_image_tokens(ctx) image_feature_size = get_max_llava_image_tokens(ctx)
if isinstance(vision_config, CLIPVisionConfig): if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip( seq_data, ranges = dummy_seq_data_for_clip(
vision_config, vision_config,
seq_len, seq_len,
num_images, num_images,
@ -120,9 +121,9 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
) )
mm_data = dummy_image_for_clip(vision_config, num_images) mm_data = dummy_image_for_clip(vision_config, num_images)
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig): elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip( seq_data, ranges = dummy_seq_data_for_siglip(
vision_config, vision_config,
seq_len, seq_len,
num_images, num_images,
@ -131,9 +132,9 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
) )
mm_data = dummy_image_for_siglip(vision_config, num_images) mm_data = dummy_image_for_siglip(vision_config, num_images)
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, PixtralVisionConfig): elif isinstance(vision_config, PixtralVisionConfig):
seq_data = dummy_seq_data_for_pixtral_hf( seq_data, ranges = dummy_seq_data_for_pixtral_hf(
vision_config, vision_config,
seq_len, seq_len,
num_images, num_images,
@ -142,7 +143,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
) )
mm_data = dummy_image_for_pixtral_hf(vision_config, num_images) mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
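The hunks above only show DummyData being constructed; the type itself lives in vllm.inputs and is presumably a small named bundle along these lines (a sketch under that assumption, with field names taken from the call sites in this commit):

# Assumed shape of DummyData; SequenceData and the placeholder dict types
# are replaced with stand-ins so the sketch stays self-contained.
from typing import Any, Dict, List, NamedTuple, Optional


class DummyData(NamedTuple):
    seq_data: Any                                      # vllm SequenceData
    multi_modal_data: Optional[Dict[str, Any]] = None  # e.g. {"image": ...}
    multi_modal_placeholders: Optional[
        Dict[str, List[Dict[str, int]]]] = None        # e.g. {"image": ranges}


# Text-only warm-up (e.g. Qwen without a visual config) can return
# DummyData(seq_data) alone; multi-modal models pass all three fields.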

View File

@ -12,7 +12,8 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig, PoolerConfig from vllm.config import CacheConfig, MultiModalConfig, PoolerConfig
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
@ -180,7 +181,7 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
max_feat_height, max_feat_width = pinpoint max_feat_height, max_feat_width = pinpoint
if isinstance(vision_config, CLIPVisionConfig): if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip( seq_data, ranges = dummy_seq_data_for_clip(
vision_config, vision_config,
seq_len, seq_len,
num_images, num_images,
@ -195,9 +196,9 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
image_height_override=max_feat_height, image_height_override=max_feat_height,
) )
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig): elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip( seq_data, ranges = dummy_seq_data_for_siglip(
vision_config, vision_config,
seq_len, seq_len,
num_images, num_images,
@ -212,7 +213,7 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
image_height_override=max_feat_height, image_height_override=max_feat_height,
) )
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)

View File

@ -11,8 +11,8 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
@ -108,33 +108,35 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
video_feature_size = frames_per_video * tokens_per_frame video_feature_size = frames_per_video * tokens_per_frame
if isinstance(vision_config, CLIPVisionConfig): if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip( seq_data, ranges = dummy_seq_data_for_clip(
vision_config, vision_config,
seq_len, seq_len,
num_videos, num_videos,
image_token_id=hf_config.video_token_index, image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size, image_feature_size_override=video_feature_size,
mm_key="video",
) )
pil_frame = dummy_image_for_clip(vision_config, num_images=1) pil_frame = dummy_image_for_clip(vision_config, num_images=1)
np_frame = np.array(pil_frame["image"]) np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0) mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
mm_data = {"video": mm_data_per_video} mm_data = {"video": mm_data_per_video}
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig): elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip( seq_data, ranges = dummy_seq_data_for_siglip(
vision_config, vision_config,
seq_len, seq_len,
num_videos, num_videos,
image_token_id=hf_config.video_token_index, image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size, image_feature_size_override=video_feature_size,
mm_key="video",
) )
pil_frame = dummy_image_for_siglip(vision_config, num_images=1) pil_frame = dummy_image_for_siglip(vision_config, num_images=1)
np_frame = np.array(pil_frame["image"]) np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0) mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
mm_data = {"video": mm_data_per_video} mm_data = {"video": mm_data_per_video}
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
@ -145,6 +147,12 @@ def input_processor_for_llava_next_video(ctx: InputContext,
multi_modal_data = inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "video" not in multi_modal_data: if multi_modal_data is None or "video" not in multi_modal_data:
return inputs return inputs
if "multi_modal_placeholders" in inputs and "video" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
video_data = multi_modal_data["video"] video_data = multi_modal_data["video"]
model_config = ctx.model_config model_config = ctx.model_config
@ -160,7 +168,7 @@ def input_processor_for_llava_next_video(ctx: InputContext,
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
inputs.get("prompt"), inputs.get("prompt"),
inputs["prompt_token_ids"], inputs["prompt_token_ids"],
@ -170,7 +178,8 @@ def input_processor_for_llava_next_video(ctx: InputContext,
return token_inputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data,
multi_modal_placeholders={"video": ranges})
elif is_list_of(video_data, np.ndarray): elif is_list_of(video_data, np.ndarray):
raise NotImplementedError( raise NotImplementedError(

View File

@ -15,8 +15,8 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
@ -218,31 +218,31 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames) video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
if isinstance(vision_config, CLIPVisionConfig): if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip( seq_data, ranges = dummy_seq_data_for_clip(
vision_config, vision_config,
seq_len, seq_len,
num_videos, num_videos,
image_token_id=hf_config.video_token_index, image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size, image_feature_size_override=video_feature_size,
) mm_key="video")
mm_data = dummy_video_for_clip(vision_config, mm_data = dummy_video_for_clip(vision_config,
num_frames=num_frames, num_frames=num_frames,
num_videos=num_videos) num_videos=num_videos)
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig): elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip( seq_data, ranges = dummy_seq_data_for_siglip(
vision_config, vision_config,
seq_len, seq_len,
num_videos, num_videos,
image_token_id=hf_config.video_token_index, image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size, image_feature_size_override=video_feature_size,
) mm_key="video")
mm_data = dummy_video_for_siglip(vision_config, mm_data = dummy_video_for_siglip(vision_config,
num_frames=num_frames, num_frames=num_frames,
num_videos=num_videos) num_videos=num_videos)
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
@ -320,7 +320,7 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames) video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
inputs.get("prompt"), inputs.get("prompt"),
inputs["prompt_token_ids"], inputs["prompt_token_ids"],
@ -330,7 +330,8 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
return token_inputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data,
multi_modal_placeholders={"video": ranges})
elif is_list_of(video_data, np.ndarray): elif is_list_of(video_data, np.ndarray):
video_feature_size = [] video_feature_size = []

View File

@ -36,8 +36,8 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
@ -277,7 +277,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images) seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images)
mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images) mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images)
return seq_data, mm_data return DummyData(seq_data, mm_data)
def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs): def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs):

View File

@ -36,7 +36,7 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
EncoderDecoderInputs, InputContext) EncoderDecoderInputs, InputContext)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -176,13 +176,14 @@ def dummy_image(num_images: int, ):
def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int, def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]): mm_counts: Mapping[str, int]):
num_images = mm_counts["image"] num_images = mm_counts["image"]
return dummy_decoder_seq_data(seq_len, num_images), None return DummyData(dummy_decoder_seq_data(seq_len, num_images))
def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int, def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]): mm_counts: Mapping[str, int]):
num_images = mm_counts["image"] num_images = mm_counts["image"]
return dummy_encoder_seq_data(ctx, num_images), dummy_image(num_images) return DummyData(dummy_encoder_seq_data(ctx, num_images),
dummy_image(num_images))
def _prepare_aspect_ratio_attention_mask( def _prepare_aspect_ratio_attention_mask(

View File

@ -7,8 +7,8 @@ from transformers import PaliGemmaConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
@ -58,7 +58,7 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
num_images = mm_counts["image"] num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_siglip( seq_data, ranges = dummy_seq_data_for_siglip(
vision_config, vision_config,
seq_len, seq_len,
num_images, num_images,
@ -66,7 +66,7 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
) )
mm_data = dummy_image_for_siglip(vision_config, num_images) mm_data = dummy_image_for_siglip(vision_config, num_images)
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
def input_processor_for_paligemma(ctx: InputContext, def input_processor_for_paligemma(ctx: InputContext,

View File

@ -28,8 +28,8 @@ from transformers import CLIPVisionConfig, PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import (CacheConfig, ModelConfig, MultiModalConfig, from vllm.config import (CacheConfig, ModelConfig, MultiModalConfig,
PoolerConfig) PoolerConfig)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
@ -380,7 +380,7 @@ def dummy_data_for_phi3v(ctx: InputContext,
image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops) image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops)
seq_data = dummy_seq_data_for_clip( seq_data, ranges = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG, CLIP_VIT_LARGE_PATCH14_336_CONFIG,
seq_len, seq_len,
num_images, num_images,
@ -394,7 +394,7 @@ def dummy_data_for_phi3v(ctx: InputContext,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
) )
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
@lru_cache @lru_cache

View File

@ -17,8 +17,8 @@ from transformers.models.pixtral.modeling_pixtral import (
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
@ -28,7 +28,8 @@ from vllm.model_executor.models.utils import merge_multimodal_embeddings
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges)
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of from vllm.utils import is_list_of
@ -81,7 +82,12 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
) )
mm_data = {"image": num_images * [image]} mm_data = {"image": num_images * [image]}
return seq_data, mm_data mm_placeholders = {
"image":
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
return DummyData(seq_data, mm_data, mm_placeholders)
def input_mapper_for_pixtral(ctx: InputContext, def input_mapper_for_pixtral(ctx: InputContext,
@ -636,7 +642,7 @@ def dummy_seq_data_for_pixtral_hf(
*, *,
image_token_id: int, image_token_id: int,
image_feature_size_override: Optional[int] = None, image_feature_size_override: Optional[int] = None,
): mm_key: str = "image"):
if image_feature_size_override is None: if image_feature_size_override is None:
image_feature_size = get_max_pixtral_hf_image_feature_size(hf_config) image_feature_size = get_max_pixtral_hf_image_feature_size(hf_config)
else: else:
@ -645,7 +651,11 @@ def dummy_seq_data_for_pixtral_hf(
return SequenceData.from_prompt_token_counts( return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images), (image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images), (0, seq_len - image_feature_size * num_images),
) ), {
mm_key:
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_image_for_pixtral_hf( def dummy_image_for_pixtral_hf(

View File

@ -23,8 +23,8 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -810,7 +810,7 @@ def dummy_data_for_qwen(
ctx: InputContext, ctx: InputContext,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> Tuple[SequenceData, Optional[Dict]]: ) -> DummyData:
"""Build dummy data for warming up Qwen models; this will only contain text """Build dummy data for warming up Qwen models; this will only contain text
matching the defaults for VLLM unless the model has a visual config. matching the defaults for VLLM unless the model has a visual config.
@ -829,7 +829,7 @@ def dummy_data_for_qwen(
if not hasattr(hf_config, "visual"): if not hasattr(hf_config, "visual"):
seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
mm_data = None mm_data = None
return seq_data, mm_data return DummyData(seq_data, mm_data)
# We have a visual component - use images to warm up # We have a visual component - use images to warm up
num_images = mm_counts["image"] num_images = mm_counts["image"]
@ -861,7 +861,7 @@ def dummy_data_for_qwen(
# the data will get resized and the # of tokens per image is constant # the data will get resized and the # of tokens per image is constant
image = Image.new("RGB", (224, 224), color=0) image = Image.new("RGB", (224, 224), color=0)
mm_data = {"image": image if num_images == 1 else [image] * num_images} mm_data = {"image": image if num_images == 1 else [image] * num_images}
return seq_data, mm_data return DummyData(seq_data, mm_data)
class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):

View File

@ -31,8 +31,8 @@ from transformers import Qwen2AudioConfig, Qwen2AudioEncoder
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
@ -44,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
@ -85,7 +86,8 @@ class Qwen2AudioMultiModalProjector(nn.Module):
def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int, def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]): mm_counts: Mapping[str, int]):
num_audios = mm_counts["audio"] num_audios = mm_counts["audio"]
max_llm_audio_tokens = get_max_qwen2_audio_audio_tokens(ctx) * num_audios max_tokens_per_audio = get_max_qwen2_audio_audio_tokens(ctx)
max_llm_audio_tokens = max_tokens_per_audio * num_audios
if seq_len - max_llm_audio_tokens - 2 < 0: if seq_len - max_llm_audio_tokens - 2 < 0:
raise RuntimeError( raise RuntimeError(
f"Qwen2-Audio cannot process {num_audios} audios in a prompt, " f"Qwen2-Audio cannot process {num_audios} audios in a prompt, "
@ -99,7 +101,12 @@ def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int,
(0, seq_len - max_llm_audio_tokens), (0, seq_len - max_llm_audio_tokens),
) )
dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.) dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.)
return dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios} return DummyData(
dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios}, {
"audio":
consecutive_placeholder_ranges(num_items=num_audios,
item_size=max_tokens_per_audio)
})
def get_processor( def get_processor(

View File

@ -44,8 +44,8 @@ from vllm.attention.selector import _Backend
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import get_pp_group, parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.activation import QuickGELU
@ -744,9 +744,10 @@ def dummy_data_for_qwen2_vl(
dummy_image = Image.new("RGB", (max_resized_width, max_resized_height), dummy_image = Image.new("RGB", (max_resized_width, max_resized_height),
color=0) color=0)
return dummy_seqdata, { return DummyData(dummy_seqdata, {
"image": dummy_image if num_images == 1 else [dummy_image] * num_images "image":
} dummy_image if num_images == 1 else [dummy_image] * num_images
})
def _get_llm_num_vision_tokens( def _get_llm_num_vision_tokens(

View File

@ -23,6 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
@ -61,6 +62,7 @@ def dummy_seq_data_for_siglip(
*, *,
image_token_id: int, image_token_id: int,
image_feature_size_override: Optional[int] = None, image_feature_size_override: Optional[int] = None,
mm_key: str = "image",
): ):
if image_feature_size_override is None: if image_feature_size_override is None:
image_feature_size = get_siglip_image_feature_size(hf_config) image_feature_size = get_siglip_image_feature_size(hf_config)
@ -70,7 +72,11 @@ def dummy_seq_data_for_siglip(
return SequenceData.from_prompt_token_counts( return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images), (image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images), (0, seq_len - image_feature_size * num_images),
) ), {
mm_key:
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_image_for_siglip( def dummy_image_for_siglip(
@ -122,6 +128,11 @@ def input_processor_for_siglip(
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return inputs return inputs
if "multi_modal_placeholders" in inputs and "image" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None: if image_feature_size_override is None:
@ -135,7 +146,7 @@ def input_processor_for_siglip(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
inputs.get("prompt"), inputs.get("prompt"),
inputs["prompt_token_ids"], inputs["prompt_token_ids"],
@ -144,11 +155,10 @@ def input_processor_for_siglip(
) )
# NOTE: Create a defensive copy of the original inputs # NOTE: Create a defensive copy of the original inputs
return token_inputs( return token_inputs(prompt_token_ids=new_token_ids,
prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data, multi_modal_data=multi_modal_data,
) multi_modal_placeholders={"image": ranges})
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa # Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa

View File

@ -2,7 +2,6 @@
"""PyTorch Ultravox model.""" """PyTorch Ultravox model."""
import math import math
from array import array
from functools import cached_property, lru_cache from functools import cached_property, lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union, cast) TypedDict, Union, cast)
@ -17,27 +16,27 @@ from transformers.models.whisper.modeling_whisper import WhisperEncoder
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
from vllm.inputs.data import DecoderOnlyInputs, token_inputs InputContext, token_inputs)
from vllm.inputs.registry import InputContext
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs,
from vllm.multimodal.base import MultiModalInputs, NestedTensors NestedTensors)
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import IntermediateTensors, SequenceData
SequenceData)
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, merge_multimodal_embeddings) init_vllm_registered_model,
merge_multimodal_embeddings_from_map)
_AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25 _AUDIO_TOKENS_PER_SECOND = 6.25
@ -46,13 +45,13 @@ _AUDIO_TOKENS_PER_SECOND = 6.25
class UltravoxAudioFeatureInputs(TypedDict): class UltravoxAudioFeatureInputs(TypedDict):
type: Literal["audio_features"] type: Literal["audio_features"]
data: NestedTensors data: NestedTensors
"""Shape: `(batch_size, num_audios, 80, M)""" """Shape: `(batch_size, num_audios, 80, M)`"""
class UltravoxAudioEmbeddingInputs(TypedDict): class UltravoxAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"] type: Literal["audio_embeds"]
data: NestedTensors data: NestedTensors
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`"""
UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
@ -79,17 +78,16 @@ def dummy_seq_data_for_ultravox(
seq_len: int, seq_len: int,
audio_count: int, audio_count: int,
): ):
audio_placeholder = array( audio_length = min(get_ultravox_max_audio_tokens(ctx),
VLLM_TOKEN_ID_ARRAY_TYPE, seq_len // audio_count)
[_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx)
# Add a separator between each chunk. return SequenceData.from_prompt_token_counts(
audio_token_ids = (audio_placeholder + (_AUDIO_PLACEHOLDER_TOKEN, audio_length * audio_count),
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count (0, seq_len - audio_length * audio_count)), {
other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, "audio":
[0]) * (seq_len - len(audio_token_ids)) consecutive_placeholder_ranges(num_items=audio_count,
item_size=audio_length)
return SequenceData(audio_token_ids + other_token_ids) }
def dummy_audio_for_ultravox( def dummy_audio_for_ultravox(
@ -107,10 +105,10 @@ def dummy_data_for_ultravox(
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
): ):
audio_count = mm_counts["audio"] audio_count = mm_counts["audio"]
seq_data = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count) seq_data, ranges = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
mm_dict = dummy_audio_for_ultravox(ctx, audio_count) mm_dict = dummy_audio_for_ultravox(ctx, audio_count)
return (seq_data, mm_dict) return DummyData(seq_data, mm_dict, ranges)
def input_mapper_for_ultravox(ctx: InputContext, data: object): def input_mapper_for_ultravox(ctx: InputContext, data: object):
@ -164,6 +162,11 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
if multi_modal_data is None or "audio" not in multi_modal_data: if multi_modal_data is None or "audio" not in multi_modal_data:
return inputs return inputs
if "multi_modal_placeholders" in inputs and "audio" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
feature_extractor = whisper_feature_extractor(ctx) feature_extractor = whisper_feature_extractor(ctx)
audios = multi_modal_data["audio"] audios = multi_modal_data["audio"]
if not isinstance(audios, list): if not isinstance(audios, list):
@ -197,7 +200,7 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
inputs.get("prompt"), inputs.get("prompt"),
inputs["prompt_token_ids"], inputs["prompt_token_ids"],
@ -208,7 +211,8 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
# NOTE: Create a defensive copy of the original inputs # NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data,
multi_modal_placeholders={"audio": ranges})
class StackAudioFrames(nn.Module): class StackAudioFrames(nn.Module):
@ -472,9 +476,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = self.language_model.model.get_input_embeddings( inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids) input_ids)
inputs_embeds = merge_multimodal_embeddings( merge_multimodal_embeddings_from_map(
input_ids, inputs_embeds, audio_embeddings, inputs_embeds, audio_embeddings,
_AUDIO_PLACEHOLDER_TOKEN) attn_metadata.multi_modal_placeholder_index_maps["audio"])
input_ids = None input_ids = None
else: else:
inputs_embeds = None inputs_embeds = None
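The ranges returned by repeat_and_pad_placeholder_tokens are what end up in multi_modal_placeholders above. A rough standalone sketch of how such ranges fall out of token repetition (the real helper also rewrites the prompt string and supports per-item repeat counts; this stand-in only expands a single placeholder id):

# Hypothetical stand-in for the range-tracking part of
# repeat_and_pad_placeholder_tokens: expand each placeholder token and
# record the span it now occupies in the new token list.
from typing import Dict, List, Tuple


def expand_placeholders(
        token_ids: List[int], placeholder_token_id: int,
        repeat_count: int) -> Tuple[List[int], List[Dict[str, int]]]:
    new_token_ids: List[int] = []
    ranges: List[Dict[str, int]] = []
    for tok in token_ids:
        if tok == placeholder_token_id:
            ranges.append({
                "offset": len(new_token_ids),
                "length": repeat_count
            })
            new_token_ids.extend([placeholder_token_id] * repeat_count)
        else:
            new_token_ids.append(tok)
    return new_token_ids, ranges


# Two audio placeholders (id 128002) between text tokens, 3 slots each:
ids, ranges = expand_placeholders([1, 128002, 2, 128002, 3], 128002, 3)
assert ids == [1, 128002, 128002, 128002, 2, 128002, 128002, 128002, 3]
assert ranges == [{"offset": 1, "length": 3}, {"offset": 5, "length": 3}]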

View File

@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.multimodal.base import NestedTensors from vllm.multimodal.base import MultiModalPlaceholderMap, NestedTensors
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
@ -326,6 +326,22 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
_embedding_count_expression(inner) for inner in embeddings) _embedding_count_expression(inner) for inner in embeddings)
def merge_multimodal_embeddings_from_map(
inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors,
placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided
placeholder map.
Note:
This updates ``inputs_embeds`` in place.
"""
flattened_embeddings = _flatten_embeddings(multimodal_embeddings)
inputs_embeds[placeholder_map.dest] = flattened_embeddings[
placeholder_map.src]
return inputs_embeds
def _merge_multimodal_embeddings( def _merge_multimodal_embeddings(
inputs_embeds: torch.Tensor, inputs_embeds: torch.Tensor,
is_multimodal: torch.Tensor, is_multimodal: torch.Tensor,
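To see what the scatter in merge_multimodal_embeddings_from_map does, a small self-contained example with a hand-built index map (the real MultiModalPlaceholderMap.IndexMap is just such a pair of src/dest index lists):

# Minimal illustration of the in-place src/dest scatter; IndexMap here is a
# local stand-in mirroring MultiModalPlaceholderMap.IndexMap.
from typing import List, NamedTuple

import torch


class IndexMap(NamedTuple):
    src: List[int]
    dest: List[int]


def merge_from_map(inputs_embeds: torch.Tensor, mm_embeds: torch.Tensor,
                   index_map: IndexMap) -> torch.Tensor:
    # Row dest[i] of the prompt embeddings is overwritten in place with
    # row src[i] of the flattened multi-modal embeddings.
    inputs_embeds[index_map.dest] = mm_embeds[index_map.src]
    return inputs_embeds


prompt = torch.zeros(6, 4)                # 6 prompt positions, hidden size 4
audio = torch.arange(12.).reshape(3, 4)   # 3 audio embeddings
merge_from_map(prompt, audio, IndexMap(src=[0, 1, 2], dest=[1, 2, 3]))
assert torch.equal(prompt[1:4], audio)
assert torch.all(prompt[0] == 0) and torch.all(prompt[4:] == 0)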

View File

@ -1,6 +1,7 @@
from .base import (BatchedTensorInputs, MultiModalDataBuiltins, from .base import (BatchedTensorInputs, MultiModalDataBuiltins,
MultiModalDataDict, MultiModalInputs, MultiModalPlugin, MultiModalDataDict, MultiModalInputs,
NestedTensors) MultiModalPlaceholderDict, MultiModalPlaceholderMap,
MultiModalPlugin, NestedTensors)
from .registry import MultiModalRegistry from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry() MULTIMODAL_REGISTRY = MultiModalRegistry()
@ -17,6 +18,8 @@ __all__ = [
"MultiModalDataBuiltins", "MultiModalDataBuiltins",
"MultiModalDataDict", "MultiModalDataDict",
"MultiModalInputs", "MultiModalInputs",
"MultiModalPlaceholderDict",
"MultiModalPlaceholderMap",
"MultiModalPlugin", "MultiModalPlugin",
"NestedTensors", "NestedTensors",
"MULTIMODAL_REGISTRY", "MULTIMODAL_REGISTRY",

View File

@ -1,8 +1,9 @@
import sys import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import UserDict, defaultdict from collections import UserDict, defaultdict
from typing import (Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping,
TypedDict, TypeVar, Union, cast, final) NamedTuple, Optional, Tuple, Type, TypedDict, TypeVar,
Union, cast, final)
import numpy as np import numpy as np
import torch import torch
@ -11,12 +12,15 @@ from PIL import Image
from torch import nn from torch import nn
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
from vllm.config import ModelConfig
from vllm.inputs import InputContext from vllm.inputs import InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of, from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of,
json_map_leaves, resolve_mm_processor_kwargs) json_map_leaves, resolve_mm_processor_kwargs)
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.sequence import SequenceGroupMetadata
logger = init_logger(__name__) logger = init_logger(__name__)
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor] NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
@ -151,6 +155,30 @@ Note:
Read more on that :ref:`here <adding_multimodal_plugin>`. Read more on that :ref:`here <adding_multimodal_plugin>`.
""" """
class PlaceholderRange(TypedDict):
"""
Placeholder location information for multi-modal data.
For example:
Prompt: AAAA BBBB What is in these images?
Images A and B will have:
A: { "offset": 0, "length": 4 }
B: { "offset": 5, "length": 4 }
"""
offset: int
"""The start index of the placeholder in the prompt."""
length: int
"""The length of the placeholder."""
MultiModalPlaceholderDict = Mapping[str, List[PlaceholderRange]]
"""
A dictionary containing placeholder ranges.
"""
MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]], MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
MultiModalInputs] MultiModalInputs]
""" """
@ -243,7 +271,7 @@ class MultiModalPlugin(ABC):
return wrapper return wrapper
def map_input(self, model_config: ModelConfig, def map_input(self, model_config: "ModelConfig",
data: MultiModalData[object], data: MultiModalData[object],
mm_processor_kwargs: Dict[str, Any]) -> MultiModalInputs: mm_processor_kwargs: Dict[str, Any]) -> MultiModalInputs:
""" """
@ -332,7 +360,7 @@ class MultiModalPlugin(ABC):
return wrapper return wrapper
def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
""" """
Get the maximum number of multi-modal tokens Get the maximum number of multi-modal tokens
for profiling the memory usage of a model. for profiling the memory usage of a model.
@ -366,3 +394,179 @@ class MultiModalPlugin(ABC):
self._validate_max_multimodal_tokens(max_mm_tokens) self._validate_max_multimodal_tokens(max_mm_tokens)
return max_mm_tokens return max_mm_tokens
class MultiModalPlaceholderMap:
"""
Relates multi-modal embeddings to their corresponding placeholders.
"""
class IndexMap(NamedTuple):
src: List[int]
dest: List[int]
src_ranges: List[range]
"""
The indices of the multi-modal embeddings that will replace the
corresponding placeholder embeddings pointed to by ``dest_ranges``.
"""
src_len: int
"""
The total number of flattened multi-modal embeddings.
"""
dest_ranges: List[range]
"""
The indices of the placeholder embeddings that will be replaced by the
multimodal embeddings.
"""
dest_len: int
"""
The total number of embeddings in the destination tensor.
"""
def __init__(self):
self.src_ranges = []
self.src_len = 0
self.dest_ranges = []
self.dest_len = 0
@classmethod
def from_seq_group(
cls, seq_group: "SequenceGroupMetadata", positions: range
) -> Tuple[Optional[MultiModalDataDict], Dict[str,
"MultiModalPlaceholderMap"]]:
"""
Returns the multi-modal items that intersect with the portion of a
prompt (``seq_group``) represented by ``positions``, as well as a
``MultiModalPlaceholderMap`` that relates the multi-modal embedding
vectors to their corresponding placeholders.
Consider the following scenarios:
Prompt: |AAAA BBBB What's in these images?|
Positions: |.................................|
images = [A, B]
src_ranges = [(0, 4), (4, 8)]
dest_ranges = [(0, 4), (5, 9)]
Prompt: |AAAA BBBB What's in these images?|
Positions: | ..... |
images = [A, B]
src_ranges = [(2, 4), (4, 6)]
dest_ranges = [(0, 2), (3, 5)]
Prompt: |AAAA BBBB What's in these images?|
Positions: | ......... |
images = [B]
src_ranges = [(0, 4)]
dest_ranges = [(0, 4)]
Prompt: |AAAA BBBB What's in these images?|
Positions: | .......................|
images = []
src_ranges = []
dest_ranges = []
"""
if (not seq_group.multi_modal_data
or not seq_group.multi_modal_placeholders):
return seq_group.multi_modal_data, {}
mm_data = {**seq_group.multi_modal_data}
placeholder_maps: Dict[str, MultiModalPlaceholderMap] = defaultdict(
MultiModalPlaceholderMap)
for modality, placeholders in seq_group.multi_modal_placeholders.items(
):
mm_items = mm_data.pop(modality)
if not isinstance(mm_items, list):
mm_items = [mm_items]
if positions:
intersecting_items = placeholder_maps[
modality].append_items_from_seq_group(
positions, mm_items, placeholders)
if intersecting_items:
mm_data[modality] = intersecting_items
return mm_data, placeholder_maps
def append_items_from_seq_group(
self, positions: range, multi_modal_items: List[_T],
multi_modal_placeholders: List[PlaceholderRange]) -> List[_T]:
"""
Adds the multi-modal items that intersect ``positions`` to this
placeholder map and returns the intersecting items.
"""
        intersecting_items = []

        if len(multi_modal_items) != len(multi_modal_placeholders):
            raise ValueError(
                "Multi-modal placeholders and items must have the same length."
            )
        for placeholder_dict, mm_item in zip(multi_modal_placeholders,
                                             multi_modal_items):
            placeholder = range(
                placeholder_dict["offset"],
                placeholder_dict["offset"] + placeholder_dict["length"])
            intersection = range(max(positions.start, placeholder.start),
                                 min(positions.stop, placeholder.stop))

            if not intersection:
                # Skip this multi-modal item.
                continue

            token_embedding_range = range(intersection.start - positions.start,
                                          intersection.stop - positions.start)
            multimodal_embedding_range = range(
                intersection.start - placeholder.start + self.src_len,
                intersection.stop - placeholder.start + self.src_len)

            intersecting_items.append(mm_item)
            self.dest_ranges.append(token_embedding_range)
            self.src_ranges.append(multimodal_embedding_range)
            self.src_len += len(placeholder)

        self.dest_len += len(positions)
        return intersecting_items

    def extend(self, other: "MultiModalPlaceholderMap"):
        """
        Adds the placeholders from another ``MultiModalPlaceholderMap`` to this
        instance based on the source and destination tensors being
        concatenated.
        """
        self.src_ranges.extend(
            range(self.src_len + r.start, self.src_len + r.stop)
            for r in other.src_ranges)
        self.src_len += other.src_len
        self.dest_ranges.extend(
            range(self.dest_len + r.start, self.dest_len + r.stop)
            for r in other.dest_ranges)
        self.dest_len += other.dest_len

    def index_map(self) -> "IndexMap":
        """
        Finalizes the placeholder map into lists of indices that can be used to
        index the source and destination tensors.
        """
        src_indices = [i for r in self.src_ranges for i in r]
        dest_indices = [i for r in self.dest_ranges for i in r]

        if len(src_indices) != len(dest_indices):
            raise ValueError(
                f"The number of source ({len(src_indices)}) and destination "
                f"indices ({len(dest_indices)}) must be the same.")

        return MultiModalPlaceholderMap.IndexMap(src=src_indices,
                                                 dest=dest_indices)
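
A minimal sketch of the second docstring scenario above, assuming the class and PlaceholderRange are importable from vllm.multimodal.base as in this commit; the "image_A"/"image_B" strings are stand-ins for real multi-modal items:

    # Sketch only: import path and item values are illustrative assumptions.
    from vllm.multimodal.base import MultiModalPlaceholderMap, PlaceholderRange

    placeholder_map = MultiModalPlaceholderMap()
    images = ["image_A", "image_B"]
    placeholders = [
        PlaceholderRange(offset=0, length=4),  # AAAA occupies prompt tokens 0..3
        PlaceholderRange(offset=5, length=4),  # BBBB occupies prompt tokens 5..8
    ]

    # A chunked-prefill step that covers only prompt positions 2..6.
    intersecting = placeholder_map.append_items_from_seq_group(
        range(2, 7), images, placeholders)

    assert intersecting == ["image_A", "image_B"]
    assert placeholder_map.src_ranges == [range(2, 4), range(4, 6)]
    assert placeholder_map.dest_ranges == [range(0, 2), range(3, 5)]

    # Flatten the ranges into gather/scatter indices over the embedding tensors.
    index_map = placeholder_map.index_map()
    assert index_map.src == [2, 3, 4, 5]
    assert index_map.dest == [0, 1, 3, 4]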

View File

@ -1,11 +1,10 @@
from functools import lru_cache from functools import lru_cache
from typing import Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Optional
import torch import torch
from PIL import Image from PIL import Image
from transformers.image_processing_base import BatchFeature from transformers.image_processing_base import BatchFeature
from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext from vllm.inputs.registry import InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.processor import get_image_processor from vllm.transformers_utils.processor import get_image_processor
@ -13,6 +12,9 @@ from vllm.utils import is_list_of
from .base import MultiModalData, MultiModalInputs, MultiModalPlugin from .base import MultiModalData, MultiModalInputs, MultiModalPlugin
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__) logger = init_logger(__name__)
cached_get_image_processor = lru_cache(get_image_processor) cached_get_image_processor = lru_cache(get_image_processor)
@ -26,7 +28,7 @@ class ImagePlugin(MultiModalPlugin):
def _get_hf_image_processor( def _get_hf_image_processor(
self, self,
model_config: ModelConfig, model_config: "ModelConfig",
mm_processor_kwargs: Optional[Dict[str, Any]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None,
): ):
if mm_processor_kwargs is None: if mm_processor_kwargs is None:

View File

@ -1,8 +1,7 @@
import functools import functools
from collections import UserDict from collections import UserDict
from typing import Any, Dict, Mapping, Optional, Sequence from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence
from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from .audio import AudioPlugin from .audio import AudioPlugin
@ -11,6 +10,9 @@ from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
from .image import ImagePlugin from .image import ImagePlugin
from .video import VideoPlugin from .video import VideoPlugin
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__) logger = init_logger(__name__)
@ -20,7 +22,7 @@ class _MultiModalLimits(UserDict):
when attempting to access a model that does not exist. when attempting to access a model that does not exist.
""" """
def __getitem__(self, key: ModelConfig) -> Dict[str, int]: def __getitem__(self, key: "ModelConfig") -> Dict[str, int]:
try: try:
return super().__getitem__(key) return super().__getitem__(key)
except KeyError as exc: except KeyError as exc:
@ -98,7 +100,7 @@ class MultiModalRegistry:
def map_input( def map_input(
self, self,
model_config: ModelConfig, model_config: "ModelConfig",
data: MultiModalDataDict, data: MultiModalDataDict,
mm_processor_kwargs: Optional[Dict[str, Any]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
@ -139,7 +141,7 @@ class MultiModalRegistry:
return MultiModalInputs(merged_dict) return MultiModalInputs(merged_dict)
def create_input_mapper(self, model_config: ModelConfig): def create_input_mapper(self, model_config: "ModelConfig"):
""" """
Create an input mapper (see :meth:`map_input`) for a specific model. Create an input mapper (see :meth:`map_input`) for a specific model.
""" """
@ -177,7 +179,7 @@ class MultiModalRegistry:
""" """
return self.register_max_multimodal_tokens("image", max_mm_tokens) return self.register_max_multimodal_tokens("image", max_mm_tokens)
def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
""" """
Get the maximum number of multi-modal tokens Get the maximum number of multi-modal tokens
for profiling the memory usage of a model. for profiling the memory usage of a model.
@ -195,7 +197,7 @@ class MultiModalRegistry:
def init_mm_limits_per_prompt( def init_mm_limits_per_prompt(
self, self,
model_config: ModelConfig, model_config: "ModelConfig",
) -> None: ) -> None:
""" """
Initialize the maximum number of multi-modal input instances for each Initialize the maximum number of multi-modal input instances for each
@ -231,7 +233,7 @@ class MultiModalRegistry:
def get_mm_limits_per_prompt( def get_mm_limits_per_prompt(
self, self,
model_config: ModelConfig, model_config: "ModelConfig",
) -> Mapping[str, int]: ) -> Mapping[str, int]:
""" """
Get the maximum number of multi-modal input instances for each modality Get the maximum number of multi-modal input instances for each modality

View File

@ -10,7 +10,7 @@ from PIL import Image
from vllm.connections import global_http_connection from vllm.connections import global_http_connection
from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal.base import MultiModalDataDict from vllm.multimodal.base import MultiModalDataDict, PlaceholderRange
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
@ -258,7 +258,7 @@ def repeat_and_pad_placeholder_tokens(
repeat_count: Union[int, List[int]], repeat_count: Union[int, List[int]],
pad_token_left: Optional[int] = None, pad_token_left: Optional[int] = None,
pad_token_right: Optional[int] = None, pad_token_right: Optional[int] = None,
) -> Tuple[Optional[str], List[int]]: ) -> Tuple[Optional[str], List[int], List[PlaceholderRange]]:
if isinstance(repeat_count, int): if isinstance(repeat_count, int):
repeat_count = [repeat_count] repeat_count = [repeat_count]
@ -301,6 +301,7 @@ def repeat_and_pad_placeholder_tokens(
new_prompt += prompt_parts[-1] new_prompt += prompt_parts[-1]
new_token_ids: List[int] = [] new_token_ids: List[int] = []
placeholder_ranges: List[PlaceholderRange] = []
placeholder_token_idx = 0 placeholder_token_idx = 0
for i, token in enumerate(prompt_token_ids): for i, token in enumerate(prompt_token_ids):
if token == placeholder_token_id: if token == placeholder_token_id:
@ -310,6 +311,10 @@ def repeat_and_pad_placeholder_tokens(
pad_token_left=pad_token_left, pad_token_left=pad_token_left,
pad_token_right=pad_token_right, pad_token_right=pad_token_right,
) )
placeholder_ranges.append({
"offset": len(new_token_ids),
"length": len(replacement_ids)
})
new_token_ids.extend(replacement_ids) new_token_ids.extend(replacement_ids)
placeholder_token_idx += 1 placeholder_token_idx += 1
@ -320,4 +325,14 @@ def repeat_and_pad_placeholder_tokens(
else: else:
new_token_ids.append(token) new_token_ids.append(token)
return new_prompt, new_token_ids return new_prompt, new_token_ids, placeholder_ranges

def consecutive_placeholder_ranges(num_items: int,
                                   item_size: int) -> List[PlaceholderRange]:
    """Returns a list of consecutive PlaceholderRanges of a fixed size"""
    return [
        PlaceholderRange(offset=i * item_size, length=item_size)
        for i in range(num_items)
    ]
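
A quick usage sketch for the helper above, assuming it is importable from vllm.multimodal.utils; the item count and size are arbitrary example values:

    # Two dummy items, four placeholder tokens each -> back-to-back ranges.
    from vllm.multimodal.utils import consecutive_placeholder_ranges

    ranges = consecutive_placeholder_ranges(num_items=2, item_size=4)
    assert ranges == [{"offset": 0, "length": 4}, {"offset": 4, "length": 4}]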

View File

@ -1,18 +1,19 @@
from functools import lru_cache from functools import lru_cache
from typing import Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import numpy as np import numpy as np
from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext from vllm.inputs.registry import InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.processor import get_video_processor from vllm.transformers_utils.processor import get_video_processor
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import is_list_of
from .base import MultiModalData, MultiModalInputs from .base import MultiModalData, MultiModalInputs
from .image import ImagePlugin from .image import ImagePlugin
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__) logger = init_logger(__name__)
cached_get_video_processor = lru_cache(get_video_processor) cached_get_video_processor = lru_cache(get_video_processor)
@ -38,7 +39,7 @@ class VideoPlugin(ImagePlugin):
def _get_hf_video_processor( def _get_hf_video_processor(
self, self,
model_config: ModelConfig, model_config: "ModelConfig",
mm_processor_kwargs: Optional[Dict[str, Any]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None,
): ):
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
@ -56,7 +57,10 @@ class VideoPlugin(ImagePlugin):
) -> MultiModalInputs: ) -> MultiModalInputs:
model_config = ctx.model_config model_config = ctx.model_config
if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray): if isinstance(data, list) and len(data) == 1:
data = data[0]
if isinstance(data, np.ndarray):
video_processor = self._get_hf_video_processor( video_processor = self._get_hf_video_processor(
model_config, model_config,
mm_processor_kwargs, mm_processor_kwargs,

View File

@ -15,13 +15,13 @@ import torch
from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.inputs import SingletonInputs from vllm.inputs import SingletonInputs
from vllm.multimodal.base import MultiModalDataDict
VLLM_TOKEN_ID_ARRAY_TYPE = "l" VLLM_TOKEN_ID_ARRAY_TYPE = "l"
@ -485,7 +485,7 @@ class Sequence:
return cast(List[int], self.inputs.get(prompt_token_ids_key)) return cast(List[int], self.inputs.get(prompt_token_ids_key))
@property @property
def multi_modal_data(self) -> "MultiModalDataDict": def multi_modal_data(self) -> MultiModalDataDict:
inputs = self.inputs inputs = self.inputs
if (inputs.get("multi_modal_data") if (inputs.get("multi_modal_data")
@ -495,11 +495,15 @@ class Sequence:
) )
return cast( return cast(
"MultiModalDataDict", MultiModalDataDict,
(inputs.get("multi_modal_data") (inputs.get("multi_modal_data")
or inputs.get("encoder_multi_modal_data") or {}), or inputs.get("encoder_multi_modal_data") or {}),
) )
@property
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
return self.inputs.get("multi_modal_placeholders") or {}
@property @property
def mm_processor_kwargs(self) -> Dict[str, Any]: def mm_processor_kwargs(self) -> Dict[str, Any]:
return self.inputs.get("mm_processor_kwargs") or {} return self.inputs.get("mm_processor_kwargs") or {}
@ -728,9 +732,13 @@ class SequenceGroup:
if self.encoder_seq is not None else None) if self.encoder_seq is not None else None)
@property @property
def multi_modal_data(self) -> "MultiModalDataDict": def multi_modal_data(self) -> MultiModalDataDict:
return self.first_seq.multi_modal_data return self.first_seq.multi_modal_data
@property
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
return self.first_seq.multi_modal_placeholders
@property @property
def mm_processor_kwargs(self) -> Dict[str, Any]: def mm_processor_kwargs(self) -> Dict[str, Any]:
return self.first_seq.mm_processor_kwargs return self.first_seq.mm_processor_kwargs
@ -946,6 +954,7 @@ class SequenceGroupMetadata(
# "MultiModalDataDict" types. We have to use Any due to msgspec # "MultiModalDataDict" types. We have to use Any due to msgspec
# doesn't allow to have union of 2 different dicts. # doesn't allow to have union of 2 different dicts.
multi_modal_data: Optional[Any] = None multi_modal_data: Optional[Any] = None
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None
encoder_seq_data: Optional[SequenceData] = None encoder_seq_data: Optional[SequenceData] = None
cross_block_table: Optional[List[int]] = None cross_block_table: Optional[List[int]] = None

View File

@ -1,5 +1,6 @@
import dataclasses import dataclasses
import weakref import weakref
from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
@ -16,7 +17,7 @@ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalInputs, MultiModalPlaceholderMap)
from vllm.sequence import (IntermediateTensors, SequenceData, from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata) SequenceGroupMetadata)
from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.config import uses_mrope
@ -148,9 +149,18 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
query_lens=seq_lens, query_lens=seq_lens,
) )
def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data, def _compute_multi_modal_input(self, seq_group: SequenceGroupMetadata,
computed_len: int, seq_data: SequenceData, computed_len: int,
mm_processor_kwargs: Dict[str, Any]): mm_processor_kwargs: Dict[str, Any]):
# NOTE: mm_data only includes the subset of multi-modal items that
# intersect with the current prefill positions.
mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
seq_group, range(computed_len, len(seq_data.get_token_ids())))
if not mm_data:
return
mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs) mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs)
# special processing for mrope position deltas. # special processing for mrope position deltas.
@ -179,7 +189,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
context_len=computed_len, context_len=computed_len,
) )
seq_data.mrope_position_delta = mrope_position_delta seq_data.mrope_position_delta = mrope_position_delta
return mm_kwargs, mrope_positions return mm_kwargs, placeholder_maps, mrope_positions
def _prepare_prompt( def _prepare_prompt(
self, self,
@ -194,6 +204,9 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
slot_mapping: List[int] = [] slot_mapping: List[int] = []
seq_lens: List[int] = [] seq_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = [] multi_modal_inputs_list: List[MultiModalInputs] = []
multi_modal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt assert seq_group_metadata.is_prompt
@ -210,11 +223,15 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
input_tokens.extend(prompt_tokens) # Token ids input_tokens.extend(prompt_tokens) # Token ids
mrope_positions = None mrope_positions = None
if (mm_data := seq_group_metadata.multi_modal_data): if seq_group_metadata.multi_modal_data:
mm_kwargs, mrope_positions = self._compute_multi_modal_input( mm_kwargs, placeholder_maps, mrope_positions = self \
seq_data, mm_data, computed_len, ._compute_multi_modal_input(
seq_group_metadata, seq_data, computed_len,
seq_group_metadata.mm_processor_kwargs) seq_group_metadata.mm_processor_kwargs)
multi_modal_inputs_list.append(mm_kwargs) multi_modal_inputs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend(
placeholder_map)
# Token position ids # Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt # NOTE(woosuk): Here we assume that the first token in the prompt
@ -264,6 +281,11 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
slot_mapping = torch.tensor(slot_mapping, slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long, dtype=torch.long,
device=self.device) # type: ignore device=self.device) # type: ignore
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
multi_modal_placeholder_maps.items()
}
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=True, is_prompt=True,
@ -275,6 +297,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
num_decode_tokens=0, num_decode_tokens=0,
block_tables=torch.tensor([]), block_tables=torch.tensor([]),
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
) )
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
@ -366,6 +389,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=False, is_prompt=False,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
max_decode_seq_len=max_decode_seq_len, max_decode_seq_len=max_decode_seq_len,
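
To make the merge step above concrete, here is a minimal sketch of how two per-sequence placeholder maps combine via extend() into batch-level index maps, as the CPU runner does before building attention metadata; the map_for helper, the import path, and all numbers are illustrative assumptions:

    from vllm.multimodal.base import MultiModalPlaceholderMap, PlaceholderRange

    def map_for(positions, placeholders):
        # Illustrative helper: build a per-sequence map for a prefill window.
        m = MultiModalPlaceholderMap()
        m.append_items_from_seq_group(positions, ["item"] * len(placeholders),
                                      placeholders)
        return m

    # Sequence 1: a 10-token prefill whose audio placeholder spans tokens 5..7.
    seq1 = map_for(range(0, 10), [PlaceholderRange(offset=5, length=3)])
    # Sequence 2: an 8-token prefill whose audio placeholder spans tokens 2..4.
    seq2 = map_for(range(0, 8), [PlaceholderRange(offset=2, length=3)])

    merged = MultiModalPlaceholderMap()
    merged.extend(seq1)
    merged.extend(seq2)

    # Embeddings 0..2 come from sequence 1 and 3..5 from sequence 2; they land
    # at the placeholder positions of the concatenated token batch.
    index_map = merged.index_map()
    assert index_map.src == [0, 1, 2, 3, 4, 5]
    assert index_map.dest == [5, 6, 7, 12, 13, 14]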

View File

@ -306,13 +306,12 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len batch_size += seq_len
decoder_seq_data, decoder_dummy_multi_modal_data \ decoder_dummy_data = self.input_registry \
= self.input_registry.dummy_data_for_profiling( .dummy_data_for_profiling(self.model_config,
self.model_config,
seq_len, seq_len,
self.mm_registry, self.mm_registry,
is_encoder_data=False) is_encoder_data=False)
encoder_seq_data, encoder_dummy_multi_modal_data \ encoder_dummy_data \
= self.input_registry.dummy_data_for_profiling( = self.input_registry.dummy_data_for_profiling(
self.model_config, self.model_config,
seq_len, seq_len,
@ -320,26 +319,31 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
is_encoder_data=True) is_encoder_data=True)
# Having more tokens is over-conservative but otherwise fine # Having more tokens is over-conservative but otherwise fine
assert len(decoder_seq_data.prompt_token_ids) >= seq_len, ( assert len(
decoder_dummy_data.seq_data.prompt_token_ids
) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, " f"Expected at least {seq_len} dummy tokens for profiling, "
f"but got: {len(decoder_seq_data.prompt_token_ids)}") f"but got: {len(decoder_dummy_data.seq_data.prompt_token_ids)}"
)
assert decoder_dummy_multi_modal_data is None or \ assert decoder_dummy_data.multi_modal_data is None or \
encoder_dummy_multi_modal_data is None, ( encoder_dummy_data.multi_modal_data is None, (
"Multi-modal data can't be provided in both encoder and decoder" "Multi-modal data can't be provided in both encoder and decoder"
) )
seq = SequenceGroupMetadata( seq = SequenceGroupMetadata(
request_id=str(group_id), request_id=str(group_id),
is_prompt=True, is_prompt=True,
seq_data={group_id: decoder_seq_data}, seq_data={group_id: decoder_dummy_data.seq_data},
sampling_params=sampling_params, sampling_params=sampling_params,
block_tables=None, block_tables=None,
encoder_seq_data=encoder_seq_data, encoder_seq_data=encoder_dummy_data.seq_data,
cross_block_table=None, cross_block_table=None,
multi_modal_data=decoder_dummy_multi_modal_data multi_modal_data=decoder_dummy_data.multi_modal_data
or encoder_dummy_multi_modal_data, or encoder_dummy_data.multi_modal_data,
) multi_modal_placeholders=decoder_dummy_data.
multi_modal_placeholders
or encoder_dummy_data.multi_modal_placeholders)
seqs.append(seq) seqs.append(seq)
# Run the model with the dummy inputs. # Run the model with the dummy inputs.

View File

@ -40,7 +40,8 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs, MultiModalRegistry) MultiModalInputs, MultiModalPlaceholderMap,
MultiModalRegistry)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
@ -242,6 +243,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Multi-modal inputs. # Multi-modal inputs.
multi_modal_inputs: Optional[MultiModalInputs] = None, multi_modal_inputs: Optional[MultiModalInputs] = None,
multi_modal_placeholder_maps: Optional[Dict[
str, MultiModalPlaceholderMap]] = None,
# Whether the prefix cache is hit (prefill only). # Whether the prefix cache is hit (prefill only).
prefix_cache_hit: bool = False, prefix_cache_hit: bool = False,
@ -361,6 +364,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.prompt_adapter_request = prompt_adapter_request self.prompt_adapter_request = prompt_adapter_request
self.multi_modal_inputs = multi_modal_inputs self.multi_modal_inputs = multi_modal_inputs
self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
self.prefix_cache_hit = prefix_cache_hit self.prefix_cache_hit = prefix_cache_hit
self.n_seqs = len(self.seq_ids) self.n_seqs = len(self.seq_ids)
@ -635,7 +639,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
seq_group_metadata: SequenceGroupMetadata): seq_group_metadata: SequenceGroupMetadata):
"""If multi-modal data is given, add it to the input.""" """If multi-modal data is given, add it to the input."""
mm_data = seq_group_metadata.multi_modal_data # NOTE: mm_data only includes the subset of multi-modal items that
# intersect with the current prefill positions.
positions = inter_data.input_positions[0]
mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
seq_group_metadata,
range(positions[0], positions[0] + len(positions)))
if not mm_data: if not mm_data:
return return
@ -643,6 +652,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
mm_data, mm_data,
mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs) mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs)
inter_data.multi_modal_inputs = mm_kwargs inter_data.multi_modal_inputs = mm_kwargs
inter_data.multi_modal_placeholder_maps = placeholder_maps
# special processing for mrope position deltas. # special processing for mrope position deltas.
if self.runner.model_is_mrope: if self.runner.model_is_mrope:
@ -1255,7 +1265,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len batch_size += seq_len
seq_data, dummy_multi_modal_data = self.input_registry \ dummy_data = self.input_registry \
.dummy_data_for_profiling(self.model_config, .dummy_data_for_profiling(self.model_config,
seq_len, seq_len,
self.mm_registry) self.mm_registry)
@ -1263,12 +1273,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
seq = SequenceGroupMetadata( seq = SequenceGroupMetadata(
request_id=str(group_id), request_id=str(group_id),
is_prompt=True, is_prompt=True,
seq_data={group_id: seq_data}, seq_data={group_id: dummy_data.seq_data},
sampling_params=sampling_params, sampling_params=sampling_params,
block_tables=None, block_tables=None,
lora_request=dummy_lora_requests_per_seq[group_id] lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None, if dummy_lora_requests_per_seq else None,
multi_modal_data=dummy_multi_modal_data, multi_modal_data=dummy_data.multi_modal_data,
multi_modal_placeholders=dummy_data.multi_modal_placeholders,
) )
seqs.append(seq) seqs.append(seq)

View File

@ -46,9 +46,8 @@ def _init_attn_metadata_from_tensor_dict(
# Extract the fields used to create AttentionMetadata. # Extract the fields used to create AttentionMetadata.
valid_attn_kwargs = {} valid_attn_kwargs = {}
for field in dataclasses.fields(attn_backend.get_metadata_cls()): for field in dataclasses.fields(attn_backend.get_metadata_cls()):
val = tensor_dict.pop(field.name, None) if field.name in tensor_dict:
if val is not None: valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)
valid_attn_kwargs[field.name] = val
attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
tensor_dict["attn_metadata"] = attn_metadata tensor_dict["attn_metadata"] = attn_metadata

View File

@ -1,4 +1,5 @@
from typing import List, NamedTuple, Optional, Tuple from collections import defaultdict
from typing import Dict, List, NamedTuple, Optional, Tuple
import openvino as ov import openvino as ov
import torch import torch
@ -14,7 +15,7 @@ from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.openvino import get_model from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalInputs, MultiModalPlaceholderMap)
from vllm.sequence import SequenceGroupMetadata from vllm.sequence import SequenceGroupMetadata
logger = init_logger(__name__) logger = init_logger(__name__)
@ -115,6 +116,9 @@ class OpenVINOModelRunner:
past_lens: List[int] = [] past_lens: List[int] = []
query_lens: List[int] = [] query_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = [] multi_modal_inputs_list: List[MultiModalInputs] = []
multi_modal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
subsequence_begins: List[int] = [] subsequence_begins: List[int] = []
block_indices: List[int] = [] block_indices: List[int] = []
@ -168,15 +172,6 @@ class OpenVINOModelRunner:
and self.sliding_window is None and self.sliding_window is None
and is_prompt) and is_prompt)
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(
mm_data,
mm_processor_kwargs=seq_group_metadata.
mm_processor_kwargs,
)
multi_modal_inputs_list.append(mm_kwargs)
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
# TODO(sang): Combine chunked prefill and prefix caching by # TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size. # only allowing multiple of block_size chunk size.
@ -220,7 +215,8 @@ class OpenVINOModelRunner:
query_lens.append(query_len) query_lens.append(query_len)
input_tokens.extend(tokens) input_tokens.extend(tokens)
input_positions.extend(list(range(computed_len, seq_len))) positions_range = range(computed_len, seq_len)
input_positions.extend(list(positions_range))
past_lens.append(computed_len) past_lens.append(computed_len)
subsequence_begins.append(subsequence_begins[-1] + query_len) subsequence_begins.append(subsequence_begins[-1] + query_len)
@ -233,6 +229,22 @@ class OpenVINOModelRunner:
), "seq_len: {}, computed_len: {}, query_len: {}".format( ), "seq_len: {}, computed_len: {}, query_len: {}".format(
seq_len, computed_len, query_len) seq_len, computed_len, query_len)
if seq_group_metadata.multi_modal_data:
# NOTE: mm_data only includes the subset of multi-modal
# items that intersect with the current prefill positions.
mm_data, placeholder_maps = MultiModalPlaceholderMap \
.from_seq_group(seq_group_metadata, positions_range)
mm_kwargs = self.multi_modal_input_mapper(
mm_data,
mm_processor_kwargs=seq_group_metadata.
mm_processor_kwargs)
multi_modal_inputs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend(
placeholder_map, )
max_query_len = max(query_lens) max_query_len = max(query_lens)
assert max_query_len > 0, "query_lens: {}".format(query_lens) assert max_query_len > 0, "query_lens: {}".format(query_lens)
@ -261,12 +273,19 @@ class OpenVINOModelRunner:
max_context_len, dtype=torch.int32, max_context_len, dtype=torch.int32,
device=self.device) # type: ignore device=self.device) # type: ignore
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
multi_modal_placeholder_maps.items()
}
attn_metadata = self.attn_backend.make_openvino_metadata( attn_metadata = self.attn_backend.make_openvino_metadata(
past_lens=past_lens_tensor, past_lens=past_lens_tensor,
subsequence_begins=subsequence_begins_tensor, subsequence_begins=subsequence_begins_tensor,
block_indices=block_indices_tensor, block_indices=block_indices_tensor,
block_indices_begins=block_indices_begins_tensor, block_indices_begins=block_indices_begins_tensor,
max_context_len=max_context_len_tensor, max_context_len=max_context_len_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
) )
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

View File

@ -184,6 +184,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_prefill_tokens=batch_size * seq_len, num_prefill_tokens=batch_size * seq_len,
num_decode_tokens=0, num_decode_tokens=0,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=None, block_tables=None,
context_lens=None, context_lens=None,
) )
@ -216,6 +217,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=batch_size * seq_len, num_decode_tokens=batch_size * seq_len,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=block_tables, block_tables=block_tables,
context_lens=context_lens, context_lens=context_lens,
) )
@ -360,6 +362,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_prefill_tokens=0, # NOTE: This is not used. num_prefill_tokens=0, # NOTE: This is not used.
num_decode_tokens=0, num_decode_tokens=0,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=None, block_tables=None,
context_lens=None, context_lens=None,
) )
@ -429,6 +432,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=batch_size, num_decode_tokens=batch_size,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=block_tables, block_tables=block_tables,
context_lens=context_lens, context_lens=context_lens,
) )

View File

@ -1,6 +1,7 @@
import dataclasses import dataclasses
import time import time
import weakref import weakref
from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type, TypeVar) Type, TypeVar)
@ -19,7 +20,8 @@ from vllm.model_executor import SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs, MultiModalRegistry) MultiModalInputs, MultiModalPlaceholderMap,
MultiModalRegistry)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import DeviceMemoryProfiler, make_tensor_with_pad from vllm.utils import DeviceMemoryProfiler, make_tensor_with_pad
@ -161,6 +163,9 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
slot_mapping: List[int] = [] slot_mapping: List[int] = []
seq_lens: List[int] = [] seq_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = [] multi_modal_inputs_list: List[MultiModalInputs] = []
multi_modal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt assert seq_group_metadata.is_prompt
@ -179,7 +184,21 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
# Token position ids # Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt # NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence. # is always the first token in the sequence.
input_positions.extend(list(range(computed_len, seq_len))) positions_range = range(computed_len, seq_len)
input_positions.extend(list(positions_range))
if seq_group_metadata.multi_modal_data:
# NOTE: mm_data only includes the subset of multi-modal items
# that intersect with the current prefill positions.
mm_data, placeholder_maps = MultiModalPlaceholderMap \
.from_seq_group(seq_group_metadata, positions_range)
mm_kwargs = self.runner.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend(
placeholder_map)
if seq_group_metadata.block_tables is None: if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized # During memory profiling, the block tables are not initialized
@ -220,6 +239,11 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
slot_mapping = torch.tensor(slot_mapping, slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long, dtype=torch.long,
device=self.device) # type: ignore device=self.device) # type: ignore
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
multi_modal_placeholder_maps.items()
}
max_seqlen = max(seq_lens) max_seqlen = max(seq_lens)
tmp = [0] tmp = [0]
@ -230,6 +254,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=True, is_prompt=True,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
seq_lens=seq_lens, seq_lens=seq_lens,
seqlen_q=seqlen_q, seqlen_q=seqlen_q,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
@ -313,6 +338,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=False, is_prompt=False,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
seq_lens=seq_lens, seq_lens=seq_lens,
seqlen_q=torch.tensor([]), seqlen_q=torch.tensor([]),
max_seqlen=0, max_seqlen=0,
@ -450,7 +476,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len batch_size += seq_len
seq_data, dummy_multi_modal_data = self.input_registry \ dummy_data = self.input_registry \
.dummy_data_for_profiling(self.model_config, .dummy_data_for_profiling(self.model_config,
seq_len, seq_len,
self.mm_registry) self.mm_registry)
@ -458,12 +484,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
seq = SequenceGroupMetadata( seq = SequenceGroupMetadata(
request_id=str(group_id), request_id=str(group_id),
is_prompt=True, is_prompt=True,
seq_data={group_id: seq_data}, seq_data={group_id: dummy_data.seq_data},
sampling_params=sampling_params, sampling_params=sampling_params,
block_tables=None, block_tables=None,
lora_request=None, lora_request=None,
multi_modal_data=dummy_multi_modal_data, multi_modal_data=dummy_data.multi_modal_data,
) multi_modal_placeholders=dummy_data.multi_modal_placeholders)
seqs.append(seq) seqs.append(seq)
# Run the model with the dummy inputs. # Run the model with the dummy inputs.