[1/N] Initial prototype for multi-modal processor (#10044)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

parent: bb7991aa29
commit: 0b8bb86bf1

@ -66,7 +66,7 @@ A default mapper is available for each modality in the core vLLM library. This i
3. Register maximum number of multi-modal tokens
------------------------------------------------

- For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data instance
+ For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data item
and register it via :meth:`INPUT_REGISTRY.register_max_multimodal_tokens <vllm.inputs.registry.InputRegistry.register_max_multimodal_tokens>`.

.. code-block:: diff
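(The body of that code block falls outside this hunk.) As a hedged sketch only, with a hypothetical model class and an illustrative token budget, a per-image registration in a model file might look like the following; `register_max_image_tokens` is the image-specific registry helper, used here as an assumption rather than quoted from this commit:

from torch import nn

from vllm.multimodal import MULTIMODAL_REGISTRY


# Illustrative sketch: register the maximum number of image tokens that a
# single image can expand to, alongside the model implementation.
@MULTIMODAL_REGISTRY.register_max_image_tokens(576)  # hypothetical budget
class YourModelForImage2Seq(nn.Module):
    ...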
@ -6,7 +6,7 @@ import torch
|
||||
from PIL.Image import Image
|
||||
|
||||
from vllm.inputs import InputContext, token_inputs
|
||||
from vllm.multimodal.base import MultiModalKwargs
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
|
||||
from .....conftest import IMAGE_ASSETS
|
||||
|
@ -1,6 +1,6 @@
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.base import MultiModalKwargs, NestedTensors
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
|
||||
|
||||
|
||||
def assert_nested_tensors_equal(expected: NestedTensors,
|
@ -1,12 +1,12 @@
|
||||
from array import array
|
||||
from typing import Mapping
|
||||
from typing import Callable, Dict, Mapping, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext,
|
||||
InputRegistry, token_inputs)
|
||||
InputRegistry, ProcessorInputs, token_inputs)
|
||||
from vllm.multimodal import MultiModalRegistry
|
||||
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
||||
|
||||
@ -34,10 +34,9 @@ def use_processor_mock():
|
||||
inputs: DecoderOnlyInputs,
|
||||
*,
|
||||
num_crops=DEFAULT_NUM_CROPS):
|
||||
# For testing purposes, we don't worry about the llm inputs / return
|
||||
# type validation, and just return the value of the kwarg that we
|
||||
# clobber.
|
||||
return num_crops
|
||||
# For testing purposes, we don't worry about the prompt
|
||||
return token_inputs(prompt_token_ids=[],
|
||||
mm_processor_kwargs={"num_crops": num_crops})
|
||||
|
||||
with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor",
|
||||
return_value=custom_processor):
|
||||
@ -109,6 +108,21 @@ def _get_num_crops_info(init_num_crops: int, inference_num_crops: int):
|
||||
return init_kwargs, inference_kwargs, expected_seq_count
|
||||
|
||||
|
||||
def _get_processed_num_crops(
|
||||
processor: Callable[[ProcessorInputs], ProcessorInputs],
|
||||
inference_kwargs: Optional[Dict[str, int]],
|
||||
) -> int:
|
||||
processed_inputs = processor(
|
||||
token_inputs(prompt_token_ids=[],
|
||||
prompt="",
|
||||
mm_processor_kwargs=inference_kwargs))
|
||||
|
||||
assert "type" in processed_inputs
|
||||
assert processed_inputs["type"] == "token"
|
||||
assert "mm_processor_kwargs" in processed_inputs
|
||||
return processed_inputs["mm_processor_kwargs"]["num_crops"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
|
||||
(None, None),
|
||||
(NUM_CROPS_OVERRIDE, None),
|
||||
@ -124,10 +138,8 @@ def test_input_processor_kwargs(use_processor_mock, init_num_crops,
|
||||
|
||||
ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
|
||||
processor = dummy_registry.create_input_processor(ctx.model_config)
|
||||
num_crops_val = processor(
|
||||
token_inputs(prompt_token_ids=[],
|
||||
prompt="",
|
||||
mm_processor_kwargs=inference_kwargs))
|
||||
num_crops_val = _get_processed_num_crops(processor, inference_kwargs)
|
||||
|
||||
assert num_crops_val == expected_seq_count
|
||||
|
||||
|
||||
@ -153,10 +165,7 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
|
||||
|
||||
processor = dummy_registry.create_input_processor(ctx.model_config)
|
||||
# Should filter out the inference time kwargs
|
||||
num_crops_val = processor(
|
||||
token_inputs(prompt_token_ids=[],
|
||||
prompt="",
|
||||
mm_processor_kwargs=mm_processor_kwargs))
|
||||
num_crops_val = _get_processed_num_crops(processor, mm_processor_kwargs)
|
||||
assert num_crops_val == DEFAULT_NUM_CROPS
|
||||
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Compare the with and without prefix caching."""
|
||||
from vllm.inputs import DecoderOnlyInputs
|
||||
from vllm.inputs import token_inputs
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
||||
from vllm.v1.core.kv_cache_utils import hash_block_tokens
|
||||
@ -8,7 +8,7 @@ from vllm.v1.core.kv_cache_utils import hash_block_tokens
|
||||
def make_request(request_id, prompt_token_ids):
|
||||
return Request(
|
||||
request_id=request_id,
|
||||
inputs=DecoderOnlyInputs(prompt_token_ids=prompt_token_ids),
|
||||
inputs=token_inputs(prompt_token_ids=prompt_token_ids),
|
||||
sampling_params=SamplingParams(max_tokens=17),
|
||||
eos_token_id=100,
|
||||
arrival_time=0,
|
||||
|
@ -107,7 +107,7 @@ class ModelConfig:
|
||||
matches the model name exposed via the APIs. If multiple model
|
||||
names provided, the first name will be used. If not specified,
|
||||
the model name will be the same as `model`.
|
||||
limit_mm_per_prompt: Maximum number of data instances per modality
|
||||
limit_mm_per_prompt: Maximum number of data items per modality
|
||||
per prompt. Only applicable for multimodal models.
|
||||
override_neuron_config: Initialize non default neuron config or
|
||||
override default neuron config that are specific to Neuron devices,
|
||||
|
@ -19,6 +19,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase
|
||||
from vllm.executor.gpu_executor import GPUExecutorAsync
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.guided_decoding import (
|
||||
@ -729,6 +730,9 @@ class AsyncLLMEngine(EngineClient):
|
||||
self.set_errored(exc)
|
||||
self._request_tracker.propagate_exception(exc)
|
||||
|
||||
async def get_input_preprocessor(self) -> InputPreprocessor:
|
||||
return self.engine.input_preprocessor
|
||||
|
||||
async def get_tokenizer(
|
||||
self,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
|
@ -30,7 +30,7 @@ from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.executor.gpu_executor import GPUExecutor
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
|
||||
PromptType)
|
||||
PromptType, SingletonInputsAdapter)
|
||||
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
@ -39,6 +39,7 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.guided_decoding import (
|
||||
get_local_guided_decoding_logits_processor)
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
|
||||
RequestOutputFactory)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
@ -226,6 +227,7 @@ class LLMEngine:
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
||||
input_registry: InputRegistry = INPUT_REGISTRY,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
use_cached_outputs: bool = False,
|
||||
) -> None:
|
||||
|
||||
@ -335,7 +337,8 @@ class LLMEngine:
|
||||
model_config)
|
||||
|
||||
self.input_preprocessor = InputPreprocessor(model_config,
|
||||
self.tokenizer)
|
||||
self.tokenizer,
|
||||
mm_registry)
|
||||
|
||||
self.input_registry = input_registry
|
||||
self.input_processor = input_registry.create_input_processor(
|
||||
@ -851,13 +854,6 @@ class LLMEngine:
|
||||
)
|
||||
processed_inputs = self.input_processor(preprocessed_inputs)
|
||||
|
||||
# This is a bit of a hack - copy the mm_processor_kwargs that were
|
||||
# used in the input processor to the processed output, since these
|
||||
# kwargs are presumed to be immutable and the values should be aligned
|
||||
# between the input processor (here) and the input mapper.
|
||||
processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
|
||||
"mm_processor_kwargs")
|
||||
|
||||
self._add_processed_request(
|
||||
request_id=request_id,
|
||||
processed_inputs=processed_inputs,
|
||||
@ -2019,7 +2015,7 @@ class LLMEngine:
|
||||
else:
|
||||
prompt_inputs = inputs
|
||||
|
||||
prompt_ids = prompt_inputs.get("prompt_token_ids")
|
||||
prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
|
||||
|
||||
if prompt_ids is None or len(prompt_ids) == 0:
|
||||
raise ValueError("Prompt cannot be empty")
|
||||
|
@ -31,6 +31,7 @@ from vllm.engine.protocol import EngineClient
|
||||
# yapf: enable
|
||||
from vllm.envs import VLLM_RPC_TIMEOUT
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
@ -94,6 +95,8 @@ class MQLLMEngineClient(EngineClient):
|
||||
parallel_config=engine_config.parallel_config,
|
||||
enable_lora=bool(engine_config.lora_config),
|
||||
)
|
||||
self.input_preprocessor = InputPreprocessor(self.model_config,
|
||||
self.tokenizer)
|
||||
|
||||
# Send RPCGenerateRequest to the MQLLMEngine.
|
||||
self.input_socket: Socket = self.context.socket(zmq.constants.PUSH)
|
||||
@ -345,6 +348,9 @@ class MQLLMEngineClient(EngineClient):
|
||||
or response != VLLM_RPC_SUCCESS_STR):
|
||||
raise ValueError(error_message)
|
||||
|
||||
async def get_input_preprocessor(self) -> InputPreprocessor:
|
||||
return self.input_preprocessor
|
||||
|
||||
async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
|
||||
return await self.tokenizer.get_lora_tokenizer_async(lora_request)
|
||||
|
||||
|
@ -62,7 +62,6 @@ class EngineClient(ABC):
|
||||
async def beam_search(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
model_config: ModelConfig,
|
||||
request_id: str,
|
||||
params: BeamSearchParams,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
@ -74,13 +73,14 @@ class EngineClient(ABC):
|
||||
length_penalty = params.length_penalty
|
||||
include_stop_str_in_output = params.include_stop_str_in_output
|
||||
|
||||
tokenizer = await self.get_tokenizer()
|
||||
input_preprocessor = InputPreprocessor(model_config, tokenizer)
|
||||
preprocessor = await self.get_input_preprocessor()
|
||||
tokenizer_group = preprocessor.get_tokenizer_group()
|
||||
tokenizer = await tokenizer_group.get_lora_tokenizer_async()
|
||||
|
||||
if is_explicit_encoder_decoder_prompt(prompt):
|
||||
raise NotImplementedError
|
||||
else:
|
||||
processed_inputs = input_preprocessor._prompt_to_llm_inputs(
|
||||
processed_inputs = preprocessor._prompt_to_llm_inputs(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
)
|
||||
@ -220,6 +220,7 @@ class EngineClient(ABC):
|
||||
Args:
|
||||
request_id: The unique id of the request.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_model_config(self) -> ModelConfig:
|
||||
@ -228,8 +229,13 @@ class EngineClient(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def get_decoding_config(self) -> DecodingConfig:
|
||||
...
|
||||
"""Get the decoding configuration of the vLLM engine."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_input_preprocessor(self) -> InputPreprocessor:
|
||||
"""Get the input processor of the vLLM engine."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_tokenizer(
|
||||
|
@ -190,7 +190,6 @@ class OpenAIServingChat(OpenAIServing):
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
generator = self.engine_client.beam_search(
|
||||
prompt=engine_prompt,
|
||||
model_config=self.model_config,
|
||||
request_id=request_id,
|
||||
params=sampling_params,
|
||||
)
|
||||
|
@ -140,7 +140,6 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
generator = self.engine_client.beam_search(
|
||||
prompt=engine_prompt,
|
||||
model_config=self.model_config,
|
||||
request_id=request_id,
|
||||
params=sampling_params,
|
||||
)
|
||||
|
@ -1,9 +1,11 @@
|
||||
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
|
||||
ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
|
||||
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
|
||||
TokensPrompt, build_explicit_enc_dec_prompt,
|
||||
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
|
||||
from .registry import DummyData, InputContext, InputRegistry
|
||||
SingletonInputs, SingletonInputsAdapter, SingletonPrompt,
|
||||
TextPrompt, TokenInputs, TokensPrompt,
|
||||
build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
|
||||
token_inputs, zip_enc_dec_prompts)
|
||||
from .registry import (DummyData, InputContext, InputProcessingContext,
|
||||
InputRegistry)
|
||||
|
||||
INPUT_REGISTRY = InputRegistry()
|
||||
"""
|
||||
@ -26,12 +28,14 @@ __all__ = [
|
||||
"EncoderDecoderInputs",
|
||||
"ProcessorInputs",
|
||||
"SingletonInputs",
|
||||
"SingletonInputsAdapter",
|
||||
"build_explicit_enc_dec_prompt",
|
||||
"to_enc_dec_tuple_list",
|
||||
"zip_enc_dec_prompts",
|
||||
"INPUT_REGISTRY",
|
||||
"DummyData",
|
||||
"InputContext",
|
||||
"InputProcessingContext",
|
||||
"InputRegistry",
|
||||
]
|
||||
|
||||
|
@ -1,10 +1,14 @@
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal,
|
||||
Optional, Tuple, Union, cast)
|
||||
|
||||
from typing_extensions import NotRequired, TypedDict, TypeVar
|
||||
import torch
|
||||
from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
|
||||
from vllm.multimodal.inputs import MultiModalInputsV2
|
||||
|
||||
|
||||
class TextPrompt(TypedDict):
|
||||
@ -36,13 +40,13 @@ class TokensPrompt(TypedDict):
|
||||
|
||||
multi_modal_data: NotRequired["MultiModalDataDict"]
|
||||
"""
|
||||
Optional multi-modal data to pass to the model,
|
||||
DEPRECATED: Optional multi-modal data to pass to the model,
|
||||
if the model supports it.
|
||||
"""
|
||||
|
||||
mm_processor_kwargs: NotRequired[Dict[str, Any]]
|
||||
"""
|
||||
Optional multi-modal processor kwargs to be forwarded to the
|
||||
DEPRECATED: Optional multi-modal processor kwargs to be forwarded to the
|
||||
multimodal input mapper & processor. Note that if multiple modalities
|
||||
have registered mappers etc for the model being considered, we attempt
|
||||
to pass the mm_processor_kwargs to each of them.
|
||||
@ -176,7 +180,7 @@ def token_inputs(
|
||||
return inputs
|
||||
|
||||
|
||||
DecoderOnlyInputs = TokenInputs
|
||||
DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputsV2"]
|
||||
"""
|
||||
The inputs in :class:`~vllm.LLMEngine` before they are
|
||||
passed to the model executor.
|
||||
@ -191,19 +195,91 @@ class EncoderDecoderInputs(TypedDict):
|
||||
|
||||
This specifies the required data for encoder-decoder models.
|
||||
"""
|
||||
encoder: TokenInputs
|
||||
encoder: Union[TokenInputs, "MultiModalInputsV2"]
|
||||
"""The inputs for the encoder portion."""
|
||||
|
||||
decoder: TokenInputs
|
||||
decoder: Union[TokenInputs, "MultiModalInputsV2"]
|
||||
"""The inputs for the decoder portion."""
|
||||
|
||||
|
||||
SingletonInputs = TokenInputs
|
||||
SingletonInputs = Union[TokenInputs, "MultiModalInputsV2"]
|
||||
"""
|
||||
A processed :class:`SingletonPrompt` which can be passed to
|
||||
:class:`vllm.sequence.Sequence`.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SingletonInputsAdapter:
|
||||
"""
|
||||
Unified interface to access the components of :class:`SingletonInputs`.
|
||||
"""
|
||||
inputs: SingletonInputs
|
||||
|
||||
@cached_property
|
||||
def prompt(self) -> Optional[str]:
|
||||
inputs = self.inputs
|
||||
|
||||
if inputs["type"] == "token" or inputs["type"] == "multimodal":
|
||||
return inputs.get("prompt")
|
||||
|
||||
assert_never(inputs)
|
||||
|
||||
@cached_property
|
||||
def prompt_token_ids(self) -> List[int]:
|
||||
inputs = self.inputs
|
||||
|
||||
if inputs["type"] == "token" or inputs["type"] == "multimodal":
|
||||
return inputs.get("prompt_token_ids", [])
|
||||
|
||||
assert_never(inputs)
|
||||
|
||||
@cached_property
|
||||
def prompt_embeds(self) -> Optional[torch.Tensor]:
|
||||
inputs = self.inputs
|
||||
|
||||
if inputs["type"] == "token" or inputs["type"] == "multimodal":
|
||||
return None
|
||||
|
||||
assert_never(inputs)
|
||||
|
||||
@cached_property
|
||||
def multi_modal_data(self) -> "MultiModalDataDict":
|
||||
inputs = self.inputs
|
||||
|
||||
if inputs["type"] == "token":
|
||||
return inputs.get("multi_modal_data", {})
|
||||
|
||||
if inputs["type"] == "multimodal":
|
||||
return inputs.get("mm_kwargs", {})
|
||||
|
||||
assert_never(inputs)
|
||||
|
||||
@cached_property
|
||||
def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
|
||||
inputs = self.inputs
|
||||
|
||||
if inputs["type"] == "token":
|
||||
return inputs.get("multi_modal_placeholders", {})
|
||||
|
||||
if inputs["type"] == "multimodal":
|
||||
return inputs.get("mm_placeholders", {})
|
||||
|
||||
assert_never(inputs)
|
||||
|
||||
@cached_property
|
||||
def mm_processor_kwargs(self) -> Dict[str, Any]:
|
||||
inputs = self.inputs
|
||||
|
||||
if inputs["type"] == "token":
|
||||
return inputs.get("mm_processor_kwargs", {})
|
||||
|
||||
if inputs["type"] == "multimodal":
|
||||
return {}
|
||||
|
||||
assert_never(inputs)
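The adapter above gives callers one accessor surface whether the inputs came from the legacy token pipeline or the new multi-modal processor, so engine code no longer needs to branch on `inputs["type"]`. A minimal usage sketch, assuming `token_inputs` from this module:

# Sketch only: the same attribute access works for "token" and
# "multimodal" inputs.
adapter = SingletonInputsAdapter(token_inputs(prompt_token_ids=[1, 2, 3]))
assert adapter.prompt_token_ids == [1, 2, 3]
assert adapter.multi_modal_data == {}       # no multi-modal data attached
assert adapter.mm_processor_kwargs == {}    # likewise, no overrides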
|
||||
|
||||
|
||||
ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
|
||||
"""
|
||||
The inputs to :data:`vllm.inputs.InputProcessor`.
|
||||
@ -234,10 +310,11 @@ def zip_enc_dec_prompts(
|
||||
) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
|
||||
"""
|
||||
Zip encoder and decoder prompts together into a list of
|
||||
:class:`ExplicitEncoderDecoderPrompt` instances. mm_processor_kwargs
|
||||
may also be provided; if a dict is passed, the same dictionary will be
|
||||
used for every encoder/decoder prompt. If an iterable is provided, it will
|
||||
be zipped with the encoder/decoder prompts.
|
||||
:class:`ExplicitEncoderDecoderPrompt` instances.
|
||||
|
||||
``mm_processor_kwargs`` may also be provided; if a dict is passed, the same
|
||||
dictionary will be used for every encoder/decoder prompt. If an iterable is
|
||||
provided, it will be zipped with the encoder/decoder prompts.
|
||||
"""
|
||||
if mm_processor_kwargs is None:
|
||||
mm_processor_kwargs = cast(Dict[str, Any], {})
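As the docstring above describes, a single kwargs dict fans out to every encoder/decoder pair, while an iterable is zipped element-wise. A small sketch with illustrative prompts only:

# Sketch only: one dict is reused for both pairs; passing a list of dicts
# instead would pair them one-to-one with the prompts.
pairs = zip_enc_dec_prompts(
    ["encoder prompt 1", "encoder prompt 2"],
    ["decoder prompt 1", "decoder prompt 2"],
    mm_processor_kwargs={"num_crops": 4},
)
assert all(p["mm_processor_kwargs"] == {"num_crops": 4} for p in pairs)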
|
||||
|
@ -1,11 +1,13 @@
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
from typing import List, Mapping, Optional, Union
|
||||
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.multimodal.processing import MultiModalDataDict, MultiModalInputsV2
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
||||
from vllm.utils import print_warning_once
|
||||
@ -23,11 +25,13 @@ class InputPreprocessor:
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
tokenizer: Optional[BaseTokenizerGroup],
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.model_config = model_config
|
||||
self.tokenizer = tokenizer
|
||||
self.mm_registry = mm_registry
|
||||
|
||||
def get_tokenizer_group(self) -> BaseTokenizerGroup:
|
||||
if self.tokenizer is None:
|
||||
@ -198,14 +202,79 @@ class InputPreprocessor:
|
||||
prompt=prompt,
|
||||
lora_request=lora_request)
|
||||
|
||||
def _can_process_multimodal(self) -> bool:
|
||||
model_config = self.model_config
|
||||
|
||||
if not model_config.is_multimodal_model:
|
||||
raise ValueError("Your model does not support multi-modal inputs")
|
||||
|
||||
# Interim measure so we can handle models that have yet to be
|
||||
# updated to use the new multi-modal processor
|
||||
can_process_multimodal = self.mm_registry.has_processor(model_config)
|
||||
if not can_process_multimodal:
|
||||
logger.info(
|
||||
"Your model uses the legacy input pipeline instead of the new "
|
||||
"multi-modal processor. Please note that the legacy pipeline "
|
||||
"will be removed in a future release. For more details, see: "
|
||||
"https://github.com/vllm-project/vllm/issues/10114")
|
||||
|
||||
return can_process_multimodal
|
||||
|
||||
def _process_multimodal(
|
||||
self,
|
||||
prompt: Union[str, List[int]],
|
||||
mm_data: MultiModalDataDict,
|
||||
mm_processor_kwargs: Optional[Mapping[str, object]],
|
||||
lora_request: Optional[LoRARequest],
|
||||
) -> MultiModalInputsV2:
|
||||
"""
|
||||
Apply the model's multi-modal processor to a multi-modal prompt,
|
||||
returning the corresponding token IDs and metadata.
|
||||
"""
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
|
||||
|
||||
mm_processor = self.mm_registry.create_processor(
|
||||
self.model_config, tokenizer)
|
||||
|
||||
if isinstance(prompt, list):
|
||||
prompt = tokenizer.decode(prompt)
|
||||
if mm_processor_kwargs is None:
|
||||
mm_processor_kwargs = {}
|
||||
|
||||
return mm_processor.apply(prompt, mm_data, mm_processor_kwargs)
|
||||
|
||||
async def _process_multimodal_async(
|
||||
self,
|
||||
prompt: Union[str, List[int]],
|
||||
mm_data: MultiModalDataDict,
|
||||
mm_processor_kwargs: Optional[Mapping[str, object]],
|
||||
lora_request: Optional[LoRARequest],
|
||||
) -> MultiModalInputsV2:
|
||||
"""Async version of :meth:`_process_multimodal`."""
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
tokenizer = await tokenizer_group.get_lora_tokenizer_async(lora_request
|
||||
)
|
||||
|
||||
mm_processor = self.mm_registry.create_processor(
|
||||
self.model_config, tokenizer)
|
||||
if isinstance(prompt, list):
|
||||
logger.warning("Passing `multi_modal_data` in TokensPrompt is"
|
||||
"deprecated and will be removed in a future update")
|
||||
prompt = tokenizer.decode(prompt)
|
||||
if mm_processor_kwargs is None:
|
||||
mm_processor_kwargs = {}
|
||||
|
||||
return mm_processor.apply(prompt, mm_data, mm_processor_kwargs)
|
||||
|
||||
def _prompt_to_llm_inputs(
|
||||
self,
|
||||
prompt: SingletonPrompt,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> SingletonInputs:
|
||||
'''
|
||||
Extract the components of any single encoder or decoder input prompt.
|
||||
"""
|
||||
Extract the singleton inputs from a prompt.
|
||||
|
||||
Arguments:
|
||||
|
||||
@ -215,12 +284,8 @@ class InputPreprocessor:
|
||||
|
||||
Returns:
|
||||
|
||||
* prompt
|
||||
* prompt_token_ids
|
||||
* multi_modal_data
|
||||
* mm_processor_kwargs (request-level input processor/mapper overrides)
|
||||
'''
|
||||
|
||||
* :class:`SingletonInputs` instance
|
||||
"""
|
||||
parsed = parse_singleton_prompt(prompt)
|
||||
|
||||
if parsed["type"] == "str":
|
||||
@ -243,6 +308,14 @@ class InputPreprocessor:
|
||||
multi_modal_data = tokens_content.get("multi_modal_data")
|
||||
mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
|
||||
|
||||
if multi_modal_data is not None and self._can_process_multimodal():
|
||||
return self._process_multimodal(
|
||||
prompt_token_ids,
|
||||
multi_modal_data,
|
||||
mm_processor_kwargs,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
return token_inputs(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_data=multi_modal_data,
|
||||
@ -253,13 +326,22 @@ class InputPreprocessor:
|
||||
text_content = parsed["content"]
|
||||
|
||||
prompt_text = text_content["prompt"]
|
||||
multi_modal_data = text_content.get("multi_modal_data")
|
||||
mm_processor_kwargs = text_content.get("mm_processor_kwargs")
|
||||
|
||||
if multi_modal_data is not None and self._can_process_multimodal():
|
||||
return self._process_multimodal(
|
||||
prompt_text,
|
||||
multi_modal_data,
|
||||
mm_processor_kwargs,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
prompt_token_ids = self._tokenize_prompt(
|
||||
prompt_text,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
multi_modal_data = text_content.get("multi_modal_data")
|
||||
mm_processor_kwargs = text_content.get("mm_processor_kwargs")
|
||||
|
||||
return token_inputs(
|
||||
prompt=prompt_text,
|
||||
@ -299,6 +381,14 @@ class InputPreprocessor:
|
||||
multi_modal_data = tokens_content.get("multi_modal_data")
|
||||
mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
|
||||
|
||||
if multi_modal_data is not None and self._can_process_multimodal():
|
||||
return await self._process_multimodal_async(
|
||||
prompt_token_ids,
|
||||
multi_modal_data,
|
||||
mm_processor_kwargs,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
return token_inputs(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_data=multi_modal_data,
|
||||
@ -309,13 +399,22 @@ class InputPreprocessor:
|
||||
text_content = parsed["content"]
|
||||
|
||||
prompt_text = text_content["prompt"]
|
||||
multi_modal_data = text_content.get("multi_modal_data")
|
||||
mm_processor_kwargs = text_content.get("mm_processor_kwargs")
|
||||
|
||||
if multi_modal_data is not None and self._can_process_multimodal():
|
||||
return await self._process_multimodal_async(
|
||||
prompt_text,
|
||||
multi_modal_data,
|
||||
mm_processor_kwargs,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
prompt_token_ids = await self._tokenize_prompt_async(
|
||||
prompt_text,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
multi_modal_data = text_content.get("multi_modal_data")
|
||||
mm_processor_kwargs = text_content.get("mm_processor_kwargs")
|
||||
|
||||
return token_inputs(
|
||||
prompt=prompt_text,
|
||||
@ -331,7 +430,8 @@ class InputPreprocessor:
|
||||
encoder_inputs: SingletonInputs,
|
||||
decoder_inputs: Optional[SingletonInputs],
|
||||
) -> EncoderDecoderInputs:
|
||||
if encoder_inputs["type"] == "token":
|
||||
if (encoder_inputs["type"] == "token"
|
||||
or encoder_inputs["type"] == "multimodal"):
|
||||
pass
|
||||
else:
|
||||
assert_never(encoder_inputs)
|
||||
@ -340,7 +440,8 @@ class InputPreprocessor:
|
||||
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
|
||||
None)
|
||||
decoder_inputs = token_inputs(dec_token_ids)
|
||||
elif decoder_inputs["type"] == "token":
|
||||
elif (decoder_inputs["type"] == "token"
|
||||
or decoder_inputs["type"] == "multimodal"):
|
||||
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
|
||||
decoder_inputs["prompt_token_ids"])
|
||||
decoder_inputs["prompt_token_ids"] = dec_token_ids
|
||||
@ -361,7 +462,7 @@ class InputPreprocessor:
|
||||
prompt: PromptType,
|
||||
request_id: str,
|
||||
) -> EncoderDecoderInputs:
|
||||
'''
|
||||
"""
|
||||
For encoder/decoder models only:
|
||||
Process an input prompt into an :class:`EncoderDecoderInputs` instance.
|
||||
|
||||
@ -391,8 +492,7 @@ class InputPreprocessor:
|
||||
Returns:
|
||||
|
||||
* :class:`EncoderDecoderInputs` instance
|
||||
'''
|
||||
|
||||
"""
|
||||
encoder_inputs: SingletonInputs
|
||||
decoder_inputs: Optional[SingletonInputs]
|
||||
|
||||
@ -460,7 +560,8 @@ class InputPreprocessor:
|
||||
prompt_inputs: DecoderOnlyInputs,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
) -> DecoderOnlyInputs:
|
||||
if prompt_inputs["type"] == "token":
|
||||
if (prompt_inputs["type"] == "token"
|
||||
or prompt_inputs["type"] == "multimodal"):
|
||||
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
|
||||
prompt_inputs["prompt_token_ids"],
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
@ -477,7 +578,7 @@ class InputPreprocessor:
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> DecoderOnlyInputs:
|
||||
'''
|
||||
"""
|
||||
For decoder-only models:
|
||||
Process an input prompt into a :class:`DecoderOnlyInputs` instance.
|
||||
|
||||
@ -491,7 +592,7 @@ class InputPreprocessor:
|
||||
Returns:
|
||||
|
||||
* :class:`DecoderOnlyInputs` instance
|
||||
'''
|
||||
"""
|
||||
|
||||
prompt_comps = self._prompt_to_llm_inputs(
|
||||
prompt,
|
||||
|
@ -5,14 +5,17 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple,
|
||||
Optional, Protocol, Type, cast)
|
||||
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from typing_extensions import TypeVar
|
||||
from transformers import PretrainedConfig, ProcessorMixin
|
||||
from typing_extensions import TypeVar, assert_never
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
|
||||
resolve_mm_processor_kwargs)
|
||||
|
||||
from .data import ProcessorInputs
|
||||
from .data import ProcessorInputs, SingletonInputs
|
||||
from .parse import is_encoder_decoder_inputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
@ -61,6 +64,19 @@ class InputContext:
|
||||
return self.model_config.hf_image_processor_config
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class InputProcessingContext(InputContext):
|
||||
tokenizer: AnyTokenizer
|
||||
"""The tokenizer used to tokenize the inputs."""
|
||||
|
||||
def get_hf_processor(self) -> ProcessorMixin:
|
||||
return cached_get_processor(
|
||||
self.model_config.tokenizer,
|
||||
tokenizer=self.tokenizer, # Override the tokenizer with ours
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
|
||||
|
||||
N = TypeVar("N", bound=Type[nn.Module])
|
||||
|
||||
|
||||
@ -94,7 +110,7 @@ class DummyDataFactory(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class _MultiModalCounts(UserDict):
|
||||
class _MultiModalCounts(UserDict[str, int]):
|
||||
"""
|
||||
Wraps `mm_counts` for a more informative error message
|
||||
when attempting to access a plugin that does not exist.
|
||||
@ -287,6 +303,21 @@ class InputRegistry:
|
||||
return self._input_processors_by_model_type \
|
||||
.get(model_cls, self._default_input_processor)
|
||||
|
||||
def _ensure_mm_kwargs(
|
||||
self,
|
||||
inputs: SingletonInputs,
|
||||
mm_processor_kwargs: Dict[str, Any],
|
||||
):
|
||||
if inputs["type"] == "token":
|
||||
# In case the input processor for that model fails to set it
|
||||
if "mm_processor_kwargs" not in inputs:
|
||||
inputs["mm_processor_kwargs"] = mm_processor_kwargs
|
||||
elif inputs["type"] == "multimodal":
|
||||
# Be more strict in V2
|
||||
assert "mm_kwargs" in inputs
|
||||
else:
|
||||
assert_never(inputs["type"])
|
||||
|
||||
def process_input(self, model_config: "ModelConfig",
|
||||
inputs: ProcessorInputs) -> ProcessorInputs:
|
||||
"""
|
||||
@ -312,8 +343,21 @@ class InputRegistry:
|
||||
processor,
|
||||
)
|
||||
|
||||
return processor(InputContext(model_config), inputs,
|
||||
**mm_processor_kwargs)
|
||||
processed_inputs = processor(
|
||||
InputContext(model_config),
|
||||
inputs,
|
||||
**mm_processor_kwargs,
|
||||
)
|
||||
|
||||
if is_encoder_decoder_inputs(processed_inputs):
|
||||
self._ensure_mm_kwargs(processed_inputs["encoder"],
|
||||
mm_processor_kwargs)
|
||||
self._ensure_mm_kwargs(processed_inputs["decoder"],
|
||||
mm_processor_kwargs)
|
||||
else:
|
||||
self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs)
|
||||
|
||||
return processed_inputs
|
||||
|
||||
def create_input_processor(self, model_config: "ModelConfig"):
|
||||
"""
|
||||
|
@ -30,8 +30,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.base import MultiModalData
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalData, MultiModalKwargs
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SequenceData)
|
||||
|
@ -32,8 +32,7 @@ from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.base import MultiModalKwargs
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.image import cached_get_image_processor
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
consecutive_placeholder_ranges)
|
||||
|
@ -15,8 +15,7 @@ from transformers import PretrainedConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
|
||||
token_inputs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.base import MultiModalKwargs
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
|
@ -25,8 +25,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.models.intern_vit import (InternVisionModel,
|
||||
InternVisionPatchModel)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.base import MultiModalKwargs
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
|
@ -17,7 +17,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.base import NestedTensors
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
|
@ -51,8 +51,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.qwen2 import Qwen2Model
|
||||
from vllm.model_executor.models.utils import LLMWrapper
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.base import MultiModalKwargs
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.image import cached_get_image_processor
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
|
@ -39,7 +39,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.base import NestedTensors, PlaceholderRange
|
||||
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
|
||||
from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.utils import is_list_of
|
||||
|
@ -29,8 +29,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.utils import merge_multimodal_embeddings
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.base import MultiModalKwargs
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
consecutive_placeholder_ranges)
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
|
@ -42,8 +42,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.base import MultiModalKwargs
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
from vllm.utils import is_list_of
|
||||
|
@ -60,10 +60,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.qwen2 import Qwen2Model
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
|
||||
MultiModalKwargs)
|
||||
from vllm.multimodal.base import MultiModalData
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import cached_get_image_processor
|
||||
from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
|
||||
MultiModalKwargs)
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
|
||||
from vllm.transformers_utils.config import uses_mrope
|
||||
|
@ -15,7 +15,7 @@ from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.multimodal.base import MultiModalPlaceholderMap, NestedTensors
|
||||
from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
@ -1,7 +1,8 @@
|
||||
from .base import (BatchedTensorInputs, MultiModalDataBuiltins,
|
||||
MultiModalDataDict, MultiModalKwargs,
|
||||
MultiModalPlaceholderDict, MultiModalPlaceholderMap,
|
||||
MultiModalPlugin, NestedTensors)
|
||||
from .base import MultiModalPlaceholderMap, MultiModalPlugin
|
||||
from .inputs import (BatchedTensorInputs, MultiModalData,
|
||||
MultiModalDataBuiltins, MultiModalDataDict,
|
||||
MultiModalKwargs, MultiModalPlaceholderDict,
|
||||
NestedTensors)
|
||||
from .registry import MultiModalRegistry
|
||||
|
||||
MULTIMODAL_REGISTRY = MultiModalRegistry()
|
||||
@ -15,6 +16,7 @@ See also:
|
||||
|
||||
__all__ = [
|
||||
"BatchedTensorInputs",
|
||||
"MultiModalData",
|
||||
"MultiModalDataBuiltins",
|
||||
"MultiModalDataDict",
|
||||
"MultiModalKwargs",
|
||||
|
@ -1,5 +1,7 @@
|
||||
from vllm.inputs.registry import InputContext
|
||||
from vllm.multimodal.base import MultiModalKwargs, MultiModalPlugin
|
||||
|
||||
from .base import MultiModalPlugin
|
||||
from .inputs import AudioItem, MultiModalData, MultiModalKwargs
|
||||
|
||||
|
||||
class AudioPlugin(MultiModalPlugin):
|
||||
@ -8,8 +10,12 @@ class AudioPlugin(MultiModalPlugin):
|
||||
def get_data_key(self) -> str:
|
||||
return "audio"
|
||||
|
||||
def _default_input_mapper(self, ctx: InputContext, data: object,
|
||||
**mm_processor_kwargs) -> MultiModalKwargs:
|
||||
def _default_input_mapper(
|
||||
self,
|
||||
ctx: InputContext,
|
||||
data: MultiModalData[AudioItem],
|
||||
**mm_processor_kwargs,
|
||||
) -> MultiModalKwargs:
|
||||
raise NotImplementedError("There is no default audio input mapper")
|
||||
|
||||
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
|
||||
|
@ -1,181 +1,24 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import UserDict, defaultdict
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping,
|
||||
NamedTuple, Optional, Tuple, Type, TypedDict, TypeVar,
|
||||
Union, cast, final)
|
||||
from collections import defaultdict
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
|
||||
Optional, Sequence, Tuple, Type, TypeVar, Union)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.types
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of,
|
||||
json_map_leaves, resolve_mm_processor_kwargs)
|
||||
from vllm.utils import (get_allowed_kwarg_only_overrides,
|
||||
resolve_mm_processor_kwargs)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.sequence import SequenceGroupMetadata
|
||||
|
||||
from .inputs import (MultiModalData, MultiModalDataDict, MultiModalKwargs,
|
||||
PlaceholderRange)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
|
||||
"""
|
||||
Uses a list instead of a tensor if the dimensions of each element do not match.
|
||||
"""
|
||||
|
||||
BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
|
||||
"""
|
||||
A dictionary containing nested tensors which have been batched via
|
||||
:meth:`MultiModalKwargs.batch`.
|
||||
"""
|
||||
|
||||
|
||||
class _MultiModalKwargsBase(UserDict[str, NestedTensors]):
|
||||
pass
|
||||
|
||||
|
||||
class MultiModalKwargs(_MultiModalKwargsBase):
|
||||
"""
|
||||
A dictionary that represents the keyword arguments to
|
||||
:meth:`~torch.nn.Module.forward`.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
|
||||
"""
|
||||
Recursively stacks lists of tensors when they all have the same shape.
|
||||
"""
|
||||
if isinstance(nested_tensors, torch.Tensor):
|
||||
return nested_tensors
|
||||
|
||||
if isinstance(nested_tensors, np.ndarray):
|
||||
return torch.from_numpy(nested_tensors)
|
||||
|
||||
if isinstance(nested_tensors, (int, float)):
|
||||
return torch.tensor(nested_tensors)
|
||||
|
||||
stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors]
|
||||
if not is_list_of(stacked, torch.Tensor, check="all"):
|
||||
# Only tensors (not lists) can be stacked.
|
||||
return stacked
|
||||
|
||||
tensors_ = cast(List[torch.Tensor], stacked)
|
||||
if any(t.shape != tensors_[0].shape for t in tensors_):
|
||||
# The tensors have incompatible shapes and can't be stacked.
|
||||
return tensors_
|
||||
|
||||
return torch.stack(tensors_)
|
||||
|
||||
@staticmethod
|
||||
def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs:
|
||||
"""
|
||||
Batch multiple inputs together into a dictionary.
|
||||
|
||||
The resulting dictionary has the same keys as the inputs.
|
||||
If the corresponding value from each input is a tensor and they all
|
||||
share the same shape, the output value is a single batched tensor;
|
||||
otherwise, the output value is a list containing the original value
|
||||
from each input.
|
||||
"""
|
||||
if len(inputs_list) == 0:
|
||||
return {}
|
||||
|
||||
item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)
|
||||
|
||||
for inputs in inputs_list:
|
||||
# For models that supports multiple modalities (e.g. Qwen2-VL),
|
||||
# different modalities will return different data keys,
|
||||
# so batch() should skip the same key check.
|
||||
|
||||
for k, v in inputs.items():
|
||||
item_lists[k].append(v)
|
||||
|
||||
return {
|
||||
k: MultiModalKwargs._try_stack(item_list)
|
||||
for k, item_list in item_lists.items()
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def as_kwargs(
|
||||
batched_inputs: BatchedTensorInputs,
|
||||
*,
|
||||
device: torch.types.Device,
|
||||
) -> BatchedTensorInputs:
|
||||
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
|
||||
|
||||
json_mapped = json_map_leaves(
|
||||
lambda x: x.to(device, non_blocking=True),
|
||||
json_inputs,
|
||||
)
|
||||
|
||||
return cast(BatchedTensorInputs, json_mapped)
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
MultiModalData: TypeAlias = Union[_T, List[_T]]
|
||||
"""
|
||||
Either a single data instance, or a list of data instances.
|
||||
|
||||
The number of data instances allowed per modality is restricted by
|
||||
`--limit-mm-per-prompt`.
|
||||
"""
|
||||
|
||||
|
||||
@final
|
||||
class MultiModalDataBuiltins(TypedDict, total=False):
|
||||
"""Modality types that are predefined by vLLM."""
|
||||
|
||||
image: MultiModalData[Image.Image]
|
||||
"""The input image(s)."""
|
||||
|
||||
audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]]
|
||||
"""The input audio item(s) and corresponding sampling rate(s)."""
|
||||
|
||||
video: MultiModalData[Tuple[np.ndarray]]
|
||||
"""The input video(s)."""
|
||||
|
||||
|
||||
MultiModalDataDict = Union[MultiModalDataBuiltins,
|
||||
Mapping[str, MultiModalData[object]]]
|
||||
"""
|
||||
A dictionary containing an item for each modality type to input.
|
||||
|
||||
Note:
|
||||
This dictionary also accepts modality keys defined outside
|
||||
:class:`MultiModalDataBuiltins` as long as a customized plugin is registered
|
||||
through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
|
||||
Read more on that :ref:`here <adding_multimodal_plugin>`.
|
||||
"""
|
||||
|
||||
|
||||
class PlaceholderRange(TypedDict):
|
||||
"""
|
||||
Placeholder location information for multi-modal data.
|
||||
|
||||
For example:
|
||||
Prompt: AAAA BBBB What is in these images?
|
||||
Images A and B will have:
|
||||
A: { "offset": 0, "length": 4 }
|
||||
B: { "offset": 5, "length": 4 }
|
||||
"""
|
||||
|
||||
offset: int
|
||||
"""The start index of the placeholder in the prompt."""
|
||||
|
||||
length: int
|
||||
"""The length of the placeholder."""
|
||||
|
||||
|
||||
MultiModalPlaceholderDict = Mapping[str, List[PlaceholderRange]]
|
||||
"""
|
||||
A dictionary containing placeholder ranges.
|
||||
"""
|
||||
|
||||
MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
|
||||
MultiModalKwargs]
|
||||
"""
|
||||
@ -192,6 +35,7 @@ Calculate the maximum number of multimodal tokens input to the language
|
||||
model. This does not include tokens that correspond to the input text.
|
||||
"""
|
||||
|
||||
_T = TypeVar("_T")
|
||||
N = TypeVar("N", bound=Type[nn.Module])
|
||||
|
||||
|
||||
@ -224,7 +68,7 @@ class MultiModalPlugin(ABC):
|
||||
def _default_input_mapper(
|
||||
self,
|
||||
ctx: InputContext,
|
||||
data: MultiModalData[object],
|
||||
data: MultiModalData[Any],
|
||||
**mm_processor_kwargs,
|
||||
) -> MultiModalKwargs:
|
||||
"""
|
||||
@ -273,8 +117,8 @@ class MultiModalPlugin(ABC):
|
||||
def map_input(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
data: MultiModalData[object],
|
||||
mm_processor_kwargs: Dict[str, Any],
|
||||
data: MultiModalData[Any],
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]],
|
||||
) -> MultiModalKwargs:
|
||||
"""
|
||||
Transform the data into a dictionary of model inputs using the
|
||||
@ -289,6 +133,7 @@ class MultiModalPlugin(ABC):
|
||||
- :ref:`input_processing_pipeline`
|
||||
- :ref:`enabling_multimodal_inputs`
|
||||
"""
|
||||
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.model_loader import get_model_architecture
|
||||
|
||||
@ -300,6 +145,9 @@ class MultiModalPlugin(ABC):
|
||||
raise KeyError(f"No input mapper in {self} is registered for "
|
||||
f"model class {model_cls.__name__}.")
|
||||
|
||||
if mm_processor_kwargs is None:
|
||||
mm_processor_kwargs = {}
|
||||
|
||||
# In the case of the default mapper, we have to get resource
|
||||
# processor through its HuggingFace autoclass; since this goes
|
||||
# through **kwargs, we can't inspect it the same way, so we allow
|
||||
@ -508,7 +356,7 @@ class MultiModalPlaceholderMap:
|
||||
self,
|
||||
positions: range,
|
||||
multi_modal_items: List[_T],
|
||||
multi_modal_placeholders: List[PlaceholderRange],
|
||||
multi_modal_placeholders: Sequence[PlaceholderRange],
|
||||
) -> List[_T]:
|
||||
"""
|
||||
Adds the multi-modal items that intersect ``positions`` to this
|
||||
|
@ -3,14 +3,14 @@ from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers.image_processing_base import BatchFeature
|
||||
|
||||
from vllm.inputs.registry import InputContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.processor import get_image_processor
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .base import MultiModalData, MultiModalKwargs, MultiModalPlugin
|
||||
from .base import MultiModalPlugin
|
||||
from .inputs import ImageItem, MultiModalData, MultiModalKwargs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
@ -41,15 +41,11 @@ class ImagePlugin(MultiModalPlugin):
|
||||
def _default_input_mapper(
|
||||
self,
|
||||
ctx: InputContext,
|
||||
data: MultiModalData[object],
|
||||
data: MultiModalData[ImageItem],
|
||||
**mm_processor_kwargs,
|
||||
) -> MultiModalKwargs:
|
||||
model_config = ctx.model_config
|
||||
|
||||
# Processed by input processor
|
||||
if isinstance(data, BatchFeature):
|
||||
return MultiModalKwargs(data.data)
|
||||
|
||||
# PIL image
|
||||
if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
|
||||
image_processor = self._get_hf_image_processor(
|
||||
|
vllm/multimodal/inputs.py (new file, 225 lines)

@ -0,0 +1,225 @@
|
||||
from collections import UserDict, defaultdict
|
||||
from typing import (Any, Dict, List, Literal, Mapping, Sequence, Tuple,
|
||||
TypedDict, TypeVar, Union, cast, final)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.types
|
||||
from PIL.Image import Image
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from vllm.utils import JSONTree, is_list_of, json_map_leaves
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
# yapf: disable
|
||||
ImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
|
||||
"""
|
||||
A :class:`transformers.image_utils.ImageInput` representing a single image,
|
||||
which can be passed to a HuggingFace :code:`ImageProcessor`.
|
||||
"""
|
||||
|
||||
VideoItem: TypeAlias = Union[
|
||||
List[Image],
|
||||
np.ndarray,
|
||||
torch.Tensor,
|
||||
List[np.ndarray],
|
||||
List[torch.Tensor],
|
||||
]
|
||||
"""
|
||||
|
||||
A :class:`transformers.image_utils.VideoInput` representing a single video,
|
||||
which can be passed to a HuggingFace :code:`VideoProcessor`.
|
||||
"""
|
||||
|
||||
AudioItem: TypeAlias = Union[
|
||||
np.ndarray,
|
||||
List[float],
|
||||
Tuple[np.ndarray, float], # DEPRECATED: Use mm_processor_kwargs instead
|
||||
]
|
||||
"""
|
||||
Represents a single audio item that can be passed to a HuggingFace
:code:`AudioProcessor`.
|
||||
"""
|
||||
# yapf: enable
|
||||
|
||||
MultiModalData: TypeAlias = Union[_T, List[_T]]
|
||||
"""
|
||||
Either a single data item, or a list of data items.
|
||||
|
||||
The number of data items allowed per modality is restricted by
|
||||
:code:`--limit-mm-per-prompt`.
|
||||
"""
|
||||
|
||||
|
||||
@final
|
||||
class MultiModalDataBuiltins(TypedDict, total=False):
|
||||
"""Type annotations for modality types predefined by vLLM."""
|
||||
|
||||
image: MultiModalData[ImageItem]
|
||||
"""The input image(s)."""
|
||||
|
||||
video: MultiModalData[VideoItem]
|
||||
"""The input video(s)."""
|
||||
|
||||
audio: MultiModalData[AudioItem]
|
||||
"""The input audio(s)."""
|
||||
|
||||
|
||||
MultiModalDataDict: TypeAlias = Mapping[str, MultiModalData[Any]]
|
||||
"""
|
||||
A dictionary containing an entry for each modality type to input.
|
||||
|
||||
Note:
|
||||
This dictionary also accepts modality keys defined outside
|
||||
:class:`MultiModalDataBuiltins` as long as a customized plugin
|
||||
is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
|
||||
Read more on that :ref:`here <adding_multimodal_plugin>`.
|
||||
"""
|
||||
|
||||
|
||||
class PlaceholderRange(TypedDict):
|
||||
"""
|
||||
Placeholder location information for multi-modal data.
|
||||
|
||||
For example:
|
||||
Prompt: AAAA BBBB What is in these images?
|
||||
Images A and B will have:
|
||||
A: { "offset": 0, "length": 4 }
|
||||
B: { "offset": 5, "length": 4 }
|
||||
"""
|
||||
|
||||
offset: int
|
||||
"""The start index of the placeholder in the prompt."""
|
||||
|
||||
length: int
|
||||
"""The length of the placeholder."""
|
||||
|
||||
|
||||
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalKwargs.batch`.
"""


class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    :meth:`~torch.nn.Module.forward`.
    """

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.

        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors]
        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(List[torch.Tensor], stacked)
        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        return torch.stack(tensors_)

    @staticmethod
    def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.
        """
        if len(inputs_list) == 0:
            return {}

        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)

        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalKwargs._try_stack(item_list)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

        json_mapped = json_map_leaves(
            lambda x: x.to(device, non_blocking=True),
            json_inputs,
        )

        return cast(BatchedTensorInputs, json_mapped)

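A small sketch of the batching rules above (shapes invented for the example): values with matching shapes are stacked into one tensor, mismatched shapes are kept as a list, and :meth:`as_kwargs` moves every leaf tensor to the target device.

a = MultiModalKwargs({"pixel_values": torch.zeros(3, 224, 224)})
b = MultiModalKwargs({"pixel_values": torch.zeros(3, 224, 224)})
c = MultiModalKwargs({"pixel_values": torch.zeros(3, 336, 336)})

batched = MultiModalKwargs.batch([a, b])
assert batched["pixel_values"].shape == (2, 3, 224, 224)  # stacked

ragged = MultiModalKwargs.batch([a, c])
assert isinstance(ragged["pixel_values"], list)  # shapes differ, kept as a list

on_device = MultiModalKwargs.as_kwargs(batched, device="cuda:0")  # assumes a GPU
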
MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges.
"""


class MultiModalInputsV2(TypedDict):
    """
    Represents the outputs of :class:`vllm.multimodal.MultiModalProcessor`,
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt: str
    """
    The original, unprocessed prompt text.

    Note:
        Since prompt text is not required by vLLM internals, we leave this
        unprocessed to save CPU computation. You can still call
        :code:`tokenizer.decode(prompt_token_ids)` to get the processed text.
    """

    prompt_token_ids: List[int]
    """The processed token IDs, which include placeholder tokens."""

    mm_kwargs: MultiModalKwargs
    """Keyword arguments to be directly passed to the model after batching."""

    mm_placeholders: MultiModalPlaceholderDict
    """
    For each modality, information about the placeholder tokens in
    :code:`prompt_token_ids`.
    """

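As a rough sketch, the processor output for a single-image prompt might look like this; the token IDs, the 576-token placeholder length, and the :code:`pixel_values` key are invented for the example.

inputs = MultiModalInputsV2(
    type="multimodal",
    prompt="<image>\nWhat is in this image?",
    prompt_token_ids=[32000] * 576 + [1724, 338, 297, 445, 1967, 29973],
    mm_kwargs=MultiModalKwargs({"pixel_values": torch.zeros(1, 3, 336, 336)}),
    mm_placeholders={"image": [PlaceholderRange(offset=0, length=576)]},
)
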
vllm/multimodal/processing.py (new file)
@ -0,0 +1,273 @@
from dataclasses import dataclass
from functools import lru_cache, partial
from typing import (Any, Callable, Collection, Generic, List, Mapping,
                    Optional, TypedDict, TypeVar, final)

from transformers import BatchFeature
from typing_extensions import TypeAlias

from vllm.inputs import InputProcessingContext
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import is_list_of

from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
                     MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
                     VideoItem)

_T = TypeVar("_T")

ReplacementFunc: TypeAlias = Callable[[_T, BatchFeature, int], List[int]]
"""
Given the original data item, HF-processed data, and index of the processed
item, output the replacement token IDs to be allocated in vLLM.
"""


@dataclass
class ModalityProcessingMetadata(Generic[_T]):
    placeholder_replacements: Mapping[str, ReplacementFunc]
    """
    A dictionary where each item represents the original placeholder in the
    prompt text and the corresponding replacement.
    """


class MultiModalProcessingMetadataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    image: ModalityProcessingMetadata[ImageItem]
    video: ModalityProcessingMetadata[VideoItem]
    audio: ModalityProcessingMetadata[AudioItem]


MultiModalProcessingMetadata: TypeAlias = \
    Mapping[str, ModalityProcessingMetadata[Any]]
"""
A dictionary containing an entry for each modality type to process.

Note:
    This dictionary also accepts modality keys defined outside
    :class:`MultiModalProcessingMetadataBuiltins` as long as a customized plugin
    is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""

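As a hedged sketch of how a model could fill in this metadata, the :code:`<image>` placeholder string, the :code:`IMAGE_TOKEN_ID` constant, and the fixed feature count below are assumptions for the example rather than part of this commit.

IMAGE_TOKEN_ID = 32000       # hypothetical placeholder token ID
NUM_IMAGE_FEATURES = 576     # hypothetical feature count per image

def _replace_image(item: ImageItem, hf_inputs: BatchFeature,
                   item_idx: int) -> List[int]:
    # Allocate one image token per vision feature of this item.
    return [IMAGE_TOKEN_ID] * NUM_IMAGE_FEATURES

metadata: MultiModalProcessingMetadata = {
    "image": ModalityProcessingMetadata(
        placeholder_replacements={"<image>": _replace_image}),
}
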
MultiModalMultiData: TypeAlias = List[_T]
"""
A list of data items, where the number of data items allowed
per modality is restricted by :code:`--limit-mm-per-prompt`.
"""


@final
class MultiModalMultiDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    image: MultiModalMultiData[ImageItem]
    """The input images."""

    video: MultiModalMultiData[VideoItem]
    """The input videos."""

    audio: MultiModalMultiData[AudioItem]
    """The input audios."""


MultiModalMultiDataDict: TypeAlias = Mapping[str, MultiModalMultiData[Any]]
"""
A dictionary containing an entry for each modality type to input.

Note:
    This dictionary also accepts modality keys defined outside
    :class:`MultiModalMultiDataBuiltins` as long as a customized plugin
    is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""


def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict:
    """
    Convert a :class:`MultiModalDataDict` containing single data items
    to a :class:`MultiModalMultiDataDict` containing multiple data items
    per entry.
    """
    multi_data: Mapping[str, MultiModalMultiData[Any]] = {}

    for k, v in data.items():
        # yapf: disable
        if k == "video":
            # Special case since even a single item can be a list
            multi_data[k] = v if is_list_of(v, list) else [v]  # type: ignore[index]
        elif k in ("image", "audio"):
            multi_data[k] = v if isinstance(v, list) else [v]  # type: ignore[index]
        else:
            multi_data[k] = v if isinstance(v, list) else [v]  # type: ignore[index]
        # yapf: enable

    return multi_data

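A quick sketch of the normalization this helper performs; the tiny PIL images are stand-ins for real inputs.

from PIL import Image

image_a = Image.new("RGB", (8, 8))
image_b = Image.new("RGB", (8, 8))

assert to_multi_format({"image": image_a}) == {"image": [image_a]}
assert to_multi_format({"image": [image_a, image_b]}) == {"image": [image_a, image_b]}
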
def encode_no_special_tokens(
    tokenizer: AnyTokenizer,
    text: str,
) -> List[int]:
    """
    Backend-agnostic equivalent of HF's
    :code:`tokenizer.encode(text, add_special_tokens=False)`.
    """
    if isinstance(tokenizer, MistralTokenizer):
        return tokenizer.tokenizer.encode(text, bos=False, eos=False)

    return tokenizer.encode(text, add_special_tokens=False)

@lru_cache
def candidate_placeholders(
    tokenizer: AnyTokenizer,
    placeholder_text: str,
) -> Collection[List[int]]:
    """Generate token ID sequences that may represent a placeholder text."""
    # When the placeholder text is not mapped to a special token ID,
    # it may be tokenized differently based on whether it is at the start/end
    # of the string. So, we go through each combination of whether the text
    # is at the start and end boundaries of the string

    # Matches the placeholder when it is in the middle of the string
    start_id, = encode_no_special_tokens(tokenizer, "a")
    end_id, = encode_no_special_tokens(tokenizer, "b")

    candidate_basic = encode_no_special_tokens(tokenizer, placeholder_text)

    start_id_, *candidate_a = encode_no_special_tokens(
        tokenizer,
        f"a{placeholder_text}",
    )
    assert start_id == start_id_

    start_id_, *candidate_ab, end_id_ = encode_no_special_tokens(
        tokenizer,
        f"a{placeholder_text}b",
    )
    assert start_id == start_id_ and end_id == end_id_

    *candidate_b, end_id_ = encode_no_special_tokens(
        tokenizer,
        f"{placeholder_text}b",
    )
    assert end_id == end_id_

    # Remove duplicates (need to convert to tuple to be hashable)
    unique_candidates = {
        tuple(c)
        for c in [candidate_basic, candidate_a, candidate_ab, candidate_b]
    }

    # Convert back to list
    return [list(c) for c in unique_candidates]

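To make the boundary issue concrete: with a SentencePiece-style tokenizer, :code:`<image>` at the start of a string may pick up a leading word-boundary marker while the same text glued to a preceding word does not, so every variant has to be considered. A hedged sketch of exercising this (the model name is only an example):

from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer("llava-hf/llava-1.5-7b-hf")  # example model
candidates = candidate_placeholders(tokenizer, "<image>")

# Each candidate is one possible token-ID encoding of the placeholder text.
for candidate in candidates:
    print(candidate)
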
def apply_placeholders(
    token_ids: List[int],
    placeholder_ids: List[int],
    get_replacement_ids: Callable[[], List[int]],
) -> Optional[PlaceholderRange]:
    """
    Find the first occurrence of :code:`placeholder_ids`,
    and replace it with the output of :code:`get_replacement_ids`.

    This function updates :code:`token_ids` in place.
    """
    placeholder_length = len(placeholder_ids)

    for start_idx in range(len(token_ids) - placeholder_length + 1):
        # Slice from start_idx to start_idx + placeholder_length so that
        # matches beyond the very start of the prompt are found as well.
        if token_ids[start_idx:start_idx + placeholder_length] == placeholder_ids:
            token_ids[start_idx:start_idx + placeholder_length] = \
                get_replacement_ids()

            return PlaceholderRange(offset=start_idx,
                                    length=placeholder_length)

    return None

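A small sketch of the in-place replacement (all token IDs are made up): the two placeholder tokens starting at index 2 are swapped for three feature tokens, and the returned range records where the match began.

token_ids = [1, 7, 101, 102, 9]

rng = apply_placeholders(token_ids, [101, 102], lambda: [200, 200, 200])

assert token_ids == [1, 7, 200, 200, 200, 9]
assert rng == PlaceholderRange(offset=2, length=2)
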
class MultiModalProcessor:
    """
    Helper class to process multi-modal inputs to be used in vLLM.
    """

    def __init__(
        self,
        ctx: InputProcessingContext,
        metadata: MultiModalProcessingMetadata,
    ) -> None:
        super().__init__()

        self.ctx = ctx
        self.metadata = metadata

    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        return self.apply(prompt, mm_data, mm_processor_kwargs)

    def apply(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        tokenizer = self.ctx.tokenizer
        hf_processor = self.ctx.get_hf_processor()

        processed_inputs = hf_processor(
            text=prompt,  # type: ignore
            **mm_data,
            **mm_processor_kwargs,
        )
        new_token_ids, = processed_inputs.pop("input_ids").tolist()
        mm_kwargs = MultiModalKwargs(processed_inputs)

        mm_placeholders: Mapping[str, List[PlaceholderRange]] = {}

        for modality, orig_inputs in to_multi_format(mm_data).items():
            assert isinstance(orig_inputs, list)

            metadata = self.metadata[modality]
            placeholder_replacements = metadata.placeholder_replacements

            modality_placeholders: List[PlaceholderRange] = []

            for item_idx, orig_item in enumerate(orig_inputs):
                for match_text, replace_fn in placeholder_replacements.items():
                    candidates = candidate_placeholders(tokenizer, match_text)
                    get_replacement_ids = partial(
                        replace_fn,
                        orig_item,
                        processed_inputs,
                        item_idx,
                    )

                    for match_ids in candidates:
                        # TODO(youkaichao): Don't update new_token_ids
                        placeholders = apply_placeholders(
                            new_token_ids,
                            match_ids,
                            get_replacement_ids,
                        )

                        if placeholders is not None:
                            modality_placeholders.append(placeholders)

            # yapf: disable
            mm_placeholders[modality] = modality_placeholders  # type: ignore[index]
            # yapf: enable

        return MultiModalInputsV2(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=new_token_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=mm_placeholders,
        )

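The registry changes that follow wire this class in through a factory. A hedged sketch of how a model implementation might hook it up; the helper :code:`_build_metadata` and the model class are placeholders, not part of this commit.

import torch.nn as nn

from vllm.multimodal import MULTIMODAL_REGISTRY

def _create_processor(ctx: InputProcessingContext) -> MultiModalProcessor:
    # Derive the per-modality placeholder replacements from the model config
    # carried by the context.
    return MultiModalProcessor(ctx=ctx, metadata=_build_metadata(ctx))

@MULTIMODAL_REGISTRY.register_processor(_create_processor)
class MyMultiModalModel(nn.Module):  # placeholder model class
    ...
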
@ -1,13 +1,20 @@
|
||||
import functools
|
||||
from collections import UserDict
|
||||
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional,
|
||||
Sequence, Type, TypeVar)
|
||||
|
||||
import torch.nn as nn
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
from .audio import AudioPlugin
|
||||
from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalKwargs,
|
||||
MultiModalPlugin, MultiModalTokensCalc, NestedTensors)
|
||||
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
|
||||
from .image import ImagePlugin
|
||||
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
|
||||
from .processing import MultiModalProcessor
|
||||
from .video import VideoPlugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -15,8 +22,18 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
N = TypeVar("N", bound=Type[nn.Module])
|
||||
|
||||
class _MultiModalLimits(UserDict):
|
||||
MultiModalProcessorFactory: TypeAlias = Callable[[InputProcessingContext],
|
||||
MultiModalProcessor]
|
||||
"""
|
||||
Constructs a :class:`MultiModalProcessor` instance from the context.
|
||||
|
||||
The processing metadata should be derived from the context.
|
||||
"""
|
||||
|
||||
|
||||
class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]):
|
||||
"""
|
||||
Wraps `_limits_by_model` for a more informative error message
|
||||
when attempting to access a model that does not exist.
|
||||
@ -45,6 +62,9 @@ class MultiModalRegistry:
|
||||
plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
|
||||
self._plugins = {p.get_data_key(): p for p in plugins}
|
||||
|
||||
self._processor_factories: Dict[Type[nn.Module],
|
||||
MultiModalProcessorFactory] = {}
|
||||
|
||||
# This is used for non-multimodal models
|
||||
self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
|
||||
|
||||
@ -243,3 +263,59 @@ class MultiModalRegistry:
|
||||
This should be called after :meth:`init_mm_limits_per_prompt`.
|
||||
"""
|
||||
return self._limits_by_model[model_config]
|
||||
|
||||
def register_processor(
|
||||
self,
|
||||
factory: MultiModalProcessorFactory,
|
||||
):
|
||||
"""
|
||||
Register a multi-modal processor to a model class.
|
||||
|
||||
When the model receives multi-modal data, the provided function is
|
||||
invoked to transform the data into a dictionary of model inputs.
|
||||
|
||||
See also:
|
||||
- :ref:`input_processing_pipeline`
|
||||
- :ref:`enabling_multimodal_inputs`
|
||||
"""
|
||||
|
||||
def wrapper(model_cls: N) -> N:
|
||||
if model_cls in self._processor_factories:
|
||||
logger.warning(
|
||||
"Model class %s already has an input mapper "
|
||||
"registered to %s. It is overwritten by the new one.",
|
||||
model_cls, self)
|
||||
|
||||
self._processor_factories[model_cls] = factory
|
||||
|
||||
return model_cls
|
||||
|
||||
return wrapper
|
||||
|
||||
def has_processor(self, model_config: "ModelConfig") -> bool:
|
||||
"""
|
||||
Test whether a multi-modal processor is defined for a specific model.
|
||||
"""
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.model_loader import get_model_architecture
|
||||
|
||||
model_cls, _ = get_model_architecture(model_config)
|
||||
return model_cls in self._processor_factories
|
||||
|
||||
def create_processor(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> MultiModalProcessor:
|
||||
"""
|
||||
Create a multi-modal processor for a specific model and tokenizer.
|
||||
"""
|
||||
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.model_loader import get_model_architecture
|
||||
|
||||
model_cls, _ = get_model_architecture(model_config)
|
||||
processor_factory = self._processor_factories[model_cls]
|
||||
|
||||
ctx = InputProcessingContext(model_config, tokenizer)
|
||||
return processor_factory(ctx)
|
||||
|
@ -11,9 +11,10 @@ from PIL import Image
|
||||
import vllm.envs as envs
|
||||
from vllm.connections import global_http_connection
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal.base import MultiModalDataDict, PlaceholderRange
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||
|
||||
from .inputs import MultiModalDataDict, PlaceholderRange
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
cached_get_tokenizer = lru_cache(get_tokenizer)
|
||||
|
@ -1,5 +1,5 @@
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -9,8 +9,9 @@ from vllm.transformers_utils.processor import get_video_processor
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .base import MultiModalData, MultiModalKwargs
|
||||
from .base import MultiModalData
|
||||
from .image import ImagePlugin
|
||||
from .inputs import MultiModalKwargs, VideoItem
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
@ -20,17 +21,6 @@ logger = init_logger(__name__)
|
||||
cached_get_video_processor = lru_cache(get_video_processor)
|
||||
cached_get_tokenizer = lru_cache(get_tokenizer)
|
||||
|
||||
VideoInput = Union[
|
||||
"np.ndarray", # single video input
|
||||
List["np.ndarray"],
|
||||
# TODO: support more types
|
||||
# List[Image.Image], List[List[Image.Image]],
|
||||
# "torch.Tensor",
|
||||
# List["torch.Tensor"],
|
||||
# List[List["np.ndarrray"]],
|
||||
# List[List["torch.Tensor"]],
|
||||
]
|
||||
|
||||
|
||||
class VideoPlugin(ImagePlugin):
|
||||
"""Plugin for video data."""
|
||||
@ -53,13 +43,13 @@ class VideoPlugin(ImagePlugin):
|
||||
def _default_input_mapper(
|
||||
self,
|
||||
ctx: InputContext,
|
||||
data: MultiModalData[object],
|
||||
data: MultiModalData[VideoItem],
|
||||
**mm_processor_kwargs,
|
||||
) -> MultiModalKwargs:
|
||||
model_config = ctx.model_config
|
||||
|
||||
if isinstance(data, list) and len(data) == 1:
|
||||
data = data[0]
|
||||
data = data[0] # type: ignore
|
||||
|
||||
if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray):
|
||||
video_processor = self._get_hf_video_processor(
|
||||
|
@ -5,25 +5,21 @@ from abc import ABC, abstractmethod
|
||||
from array import array
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property, reduce
|
||||
from typing import (TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List,
|
||||
Mapping, Optional)
|
||||
from functools import reduce
|
||||
from typing import Any, Callable, DefaultDict, Dict, List, Mapping, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Set, Tuple, Union
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.inputs import SingletonInputs, SingletonInputsAdapter
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.inputs import SingletonInputs
|
||||
|
||||
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
|
||||
|
||||
VLLM_INVALID_TOKEN_ID = -1
|
||||
@ -407,14 +403,14 @@ class Sequence:
|
||||
def __init__(
|
||||
self,
|
||||
seq_id: int,
|
||||
inputs: "SingletonInputs",
|
||||
inputs: SingletonInputs,
|
||||
block_size: int,
|
||||
eos_token_id: Optional[int] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> None:
|
||||
self.seq_id = seq_id
|
||||
self.inputs = inputs
|
||||
self.inputs = SingletonInputsAdapter(inputs)
|
||||
self.block_size = block_size
|
||||
self.eos_token_id = eos_token_id
|
||||
self.lora_request = lora_request
|
||||
@ -441,59 +437,29 @@ class Sequence:
|
||||
def n_blocks(self) -> int:
|
||||
return (self.get_len() + self.block_size - 1) // self.block_size
|
||||
|
||||
@cached_property
|
||||
@property
|
||||
def prompt(self) -> Optional[str]:
|
||||
inputs = self.inputs
|
||||
return self.inputs.prompt
|
||||
|
||||
if inputs["type"] == "token":
|
||||
return inputs.get("prompt")
|
||||
|
||||
assert_never(inputs)
|
||||
|
||||
@cached_property
|
||||
@property
|
||||
def prompt_token_ids(self) -> List[int]:
|
||||
inputs = self.inputs
|
||||
return self.inputs.prompt_token_ids
|
||||
|
||||
if inputs["type"] == "token":
|
||||
return inputs.get("prompt_token_ids", [])
|
||||
|
||||
assert_never(inputs)
|
||||
|
||||
@cached_property
|
||||
@property
|
||||
def prompt_embeds(self) -> Optional[torch.Tensor]:
|
||||
inputs = self.inputs
|
||||
return self.inputs.prompt_embeds
|
||||
|
||||
if inputs["type"] == "token":
|
||||
return None
|
||||
|
||||
assert_never(inputs)
|
||||
|
||||
@cached_property
|
||||
@property
|
||||
def multi_modal_data(self) -> "MultiModalDataDict":
|
||||
inputs = self.inputs
|
||||
|
||||
if inputs["type"] == "token":
|
||||
return inputs.get("multi_modal_data", {})
|
||||
|
||||
assert_never(inputs)
|
||||
|
||||
@cached_property
|
||||
def mm_processor_kwargs(self) -> Dict[str, Any]:
|
||||
inputs = self.inputs
|
||||
|
||||
if inputs["type"] == "token":
|
||||
return inputs.get("mm_processor_kwargs", {})
|
||||
|
||||
assert_never(inputs)
|
||||
return self.inputs.multi_modal_data
|
||||
|
||||
@property
|
||||
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
|
||||
inputs = self.inputs
|
||||
return self.inputs.multi_modal_placeholders
|
||||
|
||||
if inputs["type"] == "token":
|
||||
return inputs.get("multi_modal_placeholders", {})
|
||||
|
||||
assert_never(inputs)
|
||||
@property
|
||||
def mm_processor_kwargs(self) -> Dict[str, Any]:
|
||||
return self.inputs.mm_processor_kwargs
|
||||
|
||||
@property
|
||||
def lora_int_id(self) -> int:
|
||||
|
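Both the Sequence and Request changes above route inputs through :class:`SingletonInputsAdapter`, so the same dotted accessors work for token-based and multimodal inputs alike. A hedged sketch (the default empty values are assumed from the properties shown in this diff):

from vllm.inputs import SingletonInputsAdapter, token_inputs

adapter = SingletonInputsAdapter(
    token_inputs(prompt_token_ids=[1, 2, 3], prompt="hi"))

assert adapter.prompt == "hi"
assert adapter.prompt_token_ids == [1, 2, 3]
assert adapter.multi_modal_data == {}        # assumed default when absent
assert adapter.mm_processor_kwargs == {}     # assumed default when absent
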
@ -6,6 +6,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.metrics_types import StatLoggerBase
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
@ -321,6 +322,9 @@ class AsyncLLM(EngineClient):
|
||||
async def get_decoding_config(self):
|
||||
raise ValueError("Not Supported on V1 yet.")
|
||||
|
||||
async def get_input_preprocessor(self) -> InputPreprocessor:
|
||||
return self.processor.input_preprocessor
|
||||
|
||||
async def get_tokenizer(
|
||||
self,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
|
@ -7,6 +7,7 @@ from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
@ -32,6 +33,7 @@ class LLMEngine:
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
||||
input_registry: InputRegistry = INPUT_REGISTRY,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
use_cached_outputs: bool = False,
|
||||
multiprocess_mode: bool = False,
|
||||
) -> None:
|
||||
@ -50,7 +52,7 @@ class LLMEngine:
|
||||
# Processor (convert Inputs --> EngineCoreRequests)
|
||||
self.processor = Processor(vllm_config.model_config,
|
||||
vllm_config.lora_config, self.tokenizer,
|
||||
input_registry)
|
||||
input_registry, mm_registry)
|
||||
|
||||
# Detokenizer (converts EngineCoreOutputs --> RequestOutput)
|
||||
self.detokenizer = Detokenizer(
|
||||
|
@ -2,15 +2,17 @@ import time
|
||||
from typing import Any, Dict, Mapping, Optional, Tuple, Union
|
||||
|
||||
from vllm.config import LoRAConfig, ModelConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
|
||||
EncoderDecoderLLMInputs, InputRegistry, PromptType)
|
||||
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
|
||||
PromptType, SingletonInputsAdapter)
|
||||
from vllm.inputs.parse import is_encoder_decoder_inputs
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.config import try_get_generation_config
|
||||
from vllm.transformers_utils.tokenizer_group import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
||||
from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
|
||||
|
||||
|
||||
@ -20,8 +22,9 @@ class Processor:
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: BaseTokenizerGroup,
|
||||
input_registry: InputRegistry = INPUT_REGISTRY,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
):
|
||||
|
||||
self.model_config = model_config
|
||||
@ -31,7 +34,8 @@ class Processor:
|
||||
self.generation_config_fields = _load_generation_config_dict(
|
||||
model_config)
|
||||
self.input_preprocessor = InputPreprocessor(model_config,
|
||||
self.tokenizer)
|
||||
self.tokenizer,
|
||||
mm_registry)
|
||||
self.input_processor = input_registry.create_input_processor(
|
||||
model_config)
|
||||
|
||||
@ -73,6 +77,19 @@ class Processor:
|
||||
self._validate_model_inputs(processed_inputs)
|
||||
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
|
||||
|
||||
if is_encoder_decoder_inputs(processed_inputs):
|
||||
decoder_inputs = SingletonInputsAdapter(
|
||||
processed_inputs["decoder"])
|
||||
encoder_inputs = SingletonInputsAdapter(
|
||||
processed_inputs["encoder"])
|
||||
else:
|
||||
decoder_inputs = SingletonInputsAdapter(processed_inputs)
|
||||
encoder_inputs = None
|
||||
|
||||
# TODO: Impl encoder-decoder
|
||||
if encoder_inputs is not None:
|
||||
raise NotImplementedError
|
||||
|
||||
assert isinstance(params, SamplingParams)
|
||||
# TODO: can we avoid cloning here in multiproc case
|
||||
sampling_params = params.clone()
|
||||
@ -81,27 +98,43 @@ class Processor:
|
||||
|
||||
# Make Request for Detokenizer.
|
||||
detokenizer_request = DetokenizerRequest(
|
||||
request_id, processed_inputs.get("prompt"),
|
||||
processed_inputs.get("prompt_token_ids"),
|
||||
request_id,
|
||||
decoder_inputs.prompt,
|
||||
decoder_inputs.prompt_token_ids,
|
||||
sampling_params.skip_special_tokens,
|
||||
sampling_params.spaces_between_special_tokens,
|
||||
sampling_params.output_kind, sampling_params.stop,
|
||||
sampling_params.include_stop_str_in_output)
|
||||
sampling_params.output_kind,
|
||||
sampling_params.stop,
|
||||
sampling_params.include_stop_str_in_output,
|
||||
)
|
||||
|
||||
# Make Request for EngineCore.
|
||||
engine_core_request = EngineCoreRequest(
|
||||
request_id, processed_inputs.get("prompt"),
|
||||
processed_inputs.get("prompt_token_ids"),
|
||||
processed_inputs.get("multi_modal_data"),
|
||||
processed_inputs.get("multi_modal_placeholders"),
|
||||
processed_inputs.get("mm_processor_kwargs"), sampling_params,
|
||||
eos_token_id, arrival_time, lora_request)
|
||||
request_id,
|
||||
decoder_inputs.prompt,
|
||||
decoder_inputs.prompt_token_ids,
|
||||
decoder_inputs.multi_modal_data,
|
||||
decoder_inputs.multi_modal_placeholders,
|
||||
decoder_inputs.mm_processor_kwargs,
|
||||
sampling_params,
|
||||
eos_token_id,
|
||||
arrival_time,
|
||||
lora_request,
|
||||
)
|
||||
|
||||
return detokenizer_request, engine_core_request
|
||||
|
||||
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
|
||||
EncoderDecoderLLMInputs]):
|
||||
prompt_ids = inputs.get("prompt_token_ids")
|
||||
def _validate_model_inputs(self, inputs: ProcessorInputs):
|
||||
if is_encoder_decoder_inputs(inputs):
|
||||
# For encoder-decoder multimodal models, the max_prompt_len
|
||||
# restricts the decoder prompt length
|
||||
prompt_inputs = inputs["decoder" if self.model_config.
|
||||
is_multimodal_model else "encoder"]
|
||||
else:
|
||||
prompt_inputs = inputs
|
||||
|
||||
prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
|
||||
|
||||
if prompt_ids is None or len(prompt_ids) == 0:
|
||||
raise ValueError("Prompt cannot be empty")
|
||||
|
||||
@ -117,6 +150,10 @@ class Processor:
|
||||
"inputs, the number of image tokens depends on the number "
|
||||
"of images, and possibly their aspect ratios as well.")
|
||||
|
||||
# TODO: Find out how many placeholder tokens are there so we can
|
||||
# check that chunked prefill does not truncate them
|
||||
# max_batch_len = self.scheduler_config.max_num_batched_tokens
|
||||
|
||||
|
||||
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
|
||||
config = try_get_generation_config(
|
||||
|
@ -1,7 +1,7 @@
|
||||
import enum
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from vllm.inputs.data import DecoderOnlyInputs
|
||||
from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@ -9,23 +9,20 @@ from vllm.sequence import RequestMetrics
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.utils import ConstantList
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.inputs import DecoderOnlyInputs
|
||||
|
||||
|
||||
class Request:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
inputs: "DecoderOnlyInputs",
|
||||
inputs: DecoderOnlyInputs,
|
||||
sampling_params: SamplingParams,
|
||||
eos_token_id: Optional[int],
|
||||
arrival_time: float,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.inputs = inputs
|
||||
self.inputs = SingletonInputsAdapter(inputs)
|
||||
self.sampling_params = sampling_params
|
||||
# Because of LoRA, the eos token id can be different for each request.
|
||||
self.eos_token_id = eos_token_id
|
||||
@ -41,17 +38,17 @@ class Request:
|
||||
assert sampling_params.max_tokens is not None
|
||||
self.max_tokens = sampling_params.max_tokens
|
||||
|
||||
self.prompt = inputs.get("prompt")
|
||||
self.prompt_token_ids = inputs["prompt_token_ids"]
|
||||
self.prompt = self.inputs.prompt
|
||||
self.prompt_token_ids = self.inputs.prompt_token_ids
|
||||
self.num_prompt_tokens = len(self.prompt_token_ids)
|
||||
self._output_token_ids: List[int] = []
|
||||
self._all_token_ids: List[int] = self.prompt_token_ids.copy()
|
||||
self.num_computed_tokens = 0
|
||||
|
||||
# Raw multimodal data before the mm input mapper (e.g., PIL images).
|
||||
self.mm_data = inputs.get("multi_modal_data")
|
||||
self.mm_processor_kwargs = inputs.get("mm_processor_kwargs")
|
||||
mm_positions = inputs.get("multi_modal_placeholders")
|
||||
self.mm_data = self.inputs.multi_modal_data
|
||||
self.mm_processor_kwargs = self.inputs.mm_processor_kwargs
|
||||
mm_positions = self.inputs.multi_modal_placeholders
|
||||
if mm_positions:
|
||||
# FIXME(woosuk): Support other modalities.
|
||||
self.mm_positions = mm_positions.get("image", [])
|
||||
@ -64,8 +61,7 @@ class Request:
|
||||
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
|
||||
return cls(
|
||||
request_id=request.request_id,
|
||||
inputs=DecoderOnlyInputs(
|
||||
type="token",
|
||||
inputs=token_inputs(
|
||||
prompt_token_ids=request.prompt_token_ids,
|
||||
prompt=request.prompt,
|
||||
multi_modal_data=request.mm_data,
|
||||
@ -114,7 +110,7 @@ class Request:
|
||||
return RequestStatus.get_finished_reason(self.status)
|
||||
|
||||
def has_encoder_inputs(self) -> bool:
|
||||
return self.mm_data is not None
|
||||
return len(self.mm_data) > 0
|
||||
|
||||
@property
|
||||
def num_encoder_inputs(self) -> int:
|
||||
|
@ -28,7 +28,7 @@ from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.multimodal.base import PlaceholderRange
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.v1.core.scheduler import SchedulerOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
@ -148,19 +148,29 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
||||
query_lens=seq_lens,
|
||||
)
|
||||
|
||||
def _compute_multi_modal_input(self, seq_group: SequenceGroupMetadata,
|
||||
seq_data: SequenceData, computed_len: int,
|
||||
mm_processor_kwargs: Dict[str, Any]):
|
||||
|
||||
def _compute_multi_modal_input(
|
||||
self,
|
||||
seq_data: SequenceData,
|
||||
computed_len: int,
|
||||
seq_group_metadata: SequenceGroupMetadata,
|
||||
):
|
||||
# NOTE: mm_data only includes the subset of multi-modal items that
|
||||
# intersect with the current prefill positions.
|
||||
mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
|
||||
seq_group, range(computed_len, len(seq_data.get_token_ids())))
|
||||
seq_group_metadata,
|
||||
range(computed_len, len(seq_data.get_token_ids())),
|
||||
)
|
||||
|
||||
if not mm_data:
|
||||
return
|
||||
return None, None, None
|
||||
|
||||
mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs)
|
||||
if self.runner.mm_registry.has_processor(self.runner.model_config):
|
||||
mm_kwargs = mm_data
|
||||
else:
|
||||
mm_kwargs = self.multi_modal_input_mapper(
|
||||
mm_data,
|
||||
seq_group_metadata.mm_processor_kwargs,
|
||||
)
|
||||
|
||||
# special processing for mrope position deltas.
|
||||
mrope_positions = None
|
||||
@ -202,7 +212,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
||||
|
||||
slot_mapping: List[int] = []
|
||||
seq_lens: List[int] = []
|
||||
multi_model_kwargs_list: List[MultiModalKwargs] = []
|
||||
multi_modal_kwargs_list: List[MultiModalKwargs] = []
|
||||
multi_modal_placeholder_maps: Dict[
|
||||
str,
|
||||
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
|
||||
@ -223,11 +233,14 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
||||
|
||||
mrope_positions = None
|
||||
if seq_group_metadata.multi_modal_data:
|
||||
mm_kwargs, placeholder_maps, mrope_positions = self \
|
||||
._compute_multi_modal_input(
|
||||
seq_group_metadata, seq_data, computed_len,
|
||||
seq_group_metadata.mm_processor_kwargs)
|
||||
multi_model_kwargs_list.append(mm_kwargs)
|
||||
(
|
||||
mm_kwargs,
|
||||
placeholder_maps,
|
||||
mrope_positions,
|
||||
) = self._compute_multi_modal_input(seq_data, computed_len,
|
||||
seq_group_metadata)
|
||||
|
||||
multi_modal_kwargs_list.append(mm_kwargs)
|
||||
for modality, placeholder_map in placeholder_maps.items():
|
||||
multi_modal_placeholder_maps[modality].extend(
|
||||
placeholder_map)
|
||||
@ -302,7 +315,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
)
|
||||
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list)
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
|
||||
|
||||
return (input_tokens, input_positions, attn_metadata, seq_lens,
|
||||
multi_modal_kwargs)
|
||||
|
@ -716,7 +716,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
|
||||
context_lens: List[int] = []
|
||||
query_lens: List[int] = []
|
||||
prefix_block_tables: List[List[int]] = []
|
||||
multi_model_kwargs_list: List[MultiModalKwargs] = []
|
||||
multi_modal_kwargs_list: List[MultiModalKwargs] = []
|
||||
|
||||
if len(seq_group_metadata_list) == 0:
|
||||
return PreparePromptMetadata.empty()
|
||||
@ -777,7 +777,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
|
||||
mm_data = seq_group_metadata.multi_modal_data
|
||||
if mm_data:
|
||||
mm_kwargs = self.multi_modal_input_mapper(mm_data)
|
||||
multi_model_kwargs_list.append(mm_kwargs)
|
||||
multi_modal_kwargs_list.append(mm_kwargs)
|
||||
|
||||
if seq_group_metadata.block_tables is None:
|
||||
# During memory profiling, the block tables are not initialized
|
||||
@ -876,7 +876,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
|
||||
multi_modal_placeholder_index_maps=
|
||||
None # FIXME(kzawora): mutli-modality will not work here
|
||||
)
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list)
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
|
||||
|
||||
return PreparePromptMetadata(input_tokens=input_tokens,
|
||||
input_positions=input_positions,
|
||||
|
@ -252,7 +252,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
|
||||
# Multi-modal inputs.
|
||||
multi_model_kwargs: Optional[MultiModalKwargs] = None,
|
||||
multi_modal_kwargs: Optional[MultiModalKwargs] = None,
|
||||
multi_modal_placeholder_maps: Optional[Dict[
|
||||
str, MultiModalPlaceholderMap]] = None,
|
||||
|
||||
@ -373,7 +373,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
prompt_adapter_prompt_mapping or [])
|
||||
|
||||
self.prompt_adapter_request = prompt_adapter_request
|
||||
self.multi_model_kwargs = multi_model_kwargs
|
||||
self.multi_modal_kwargs = multi_modal_kwargs
|
||||
self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
|
||||
self.prefix_cache_hit = prefix_cache_hit
|
||||
|
||||
@ -661,10 +661,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
if not mm_data:
|
||||
return
|
||||
|
||||
mm_kwargs = self.multi_modal_input_mapper(
|
||||
mm_data,
|
||||
mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs)
|
||||
inter_data.multi_model_kwargs = mm_kwargs
|
||||
if self.runner.mm_registry.has_processor(self.runner.model_config):
|
||||
mm_kwargs = mm_data
|
||||
else:
|
||||
mm_kwargs = self.multi_modal_input_mapper(
|
||||
mm_data,
|
||||
seq_group_metadata.mm_processor_kwargs,
|
||||
)
|
||||
|
||||
inter_data.multi_modal_kwargs = mm_kwargs
|
||||
inter_data.multi_modal_placeholder_maps = placeholder_maps
|
||||
|
||||
# special processing for mrope position deltas.
|
||||
@ -938,11 +943,11 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
)
|
||||
|
||||
# Multi-modal data.
|
||||
multi_model_kwargs_list = [
|
||||
data.multi_model_kwargs for data in self.inter_data_list
|
||||
if data.multi_model_kwargs is not None
|
||||
multi_modal_kwargs_list = [
|
||||
data.multi_modal_kwargs for data in self.inter_data_list
|
||||
if data.multi_modal_kwargs is not None
|
||||
]
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list)
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
|
||||
|
||||
return self.model_input_cls(
|
||||
input_tokens=input_tokens_tensor,
|
||||
|
@ -67,7 +67,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
|
||||
# Multi-modal data support
|
||||
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
self.multi_modal_input_mapper = self.mm_registry \
|
||||
.create_input_mapper(self.model_config)
|
||||
|
||||
# Lazy initialization.
|
||||
@ -122,7 +123,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
|
||||
input_block_ids: List[int] = []
|
||||
|
||||
seq_lens: List[int] = []
|
||||
multi_model_kwargs_list: List[MultiModalKwargs] = []
|
||||
multi_modal_kwargs_list: List[MultiModalKwargs] = []
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
assert seq_group_metadata.is_prompt
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
@ -144,12 +145,15 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
|
||||
|
||||
mm_data = seq_group_metadata.multi_modal_data
|
||||
if mm_data:
|
||||
# Process multi-modal data
|
||||
mm_kwargs = self.multi_modal_input_mapper(
|
||||
mm_data,
|
||||
mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs,
|
||||
)
|
||||
multi_model_kwargs_list.append(mm_kwargs)
|
||||
if self.mm_registry.has_processor(self.model_config):
|
||||
mm_kwargs = mm_data
|
||||
else:
|
||||
mm_kwargs = self.multi_modal_input_mapper(
|
||||
mm_data,
|
||||
seq_group_metadata.mm_processor_kwargs,
|
||||
)
|
||||
|
||||
multi_modal_kwargs_list.append(mm_kwargs)
|
||||
|
||||
max_seq_len = max(seq_lens)
|
||||
assert max_seq_len > 0
|
||||
@ -167,7 +171,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list)
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
|
||||
|
||||
return (input_tokens, input_positions, input_block_ids, seq_lens,
|
||||
multi_modal_kwargs)
|
||||
|
@ -70,7 +70,8 @@ class OpenVINOModelRunner(ModelRunnerBase):
|
||||
)
|
||||
|
||||
# Multi-modal data support
|
||||
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
self.multi_modal_input_mapper = self.mm_registry \
|
||||
.create_input_mapper(self.model_config)
|
||||
|
||||
# Lazy initialization.
|
||||
@ -102,7 +103,7 @@ class OpenVINOModelRunner(ModelRunnerBase):
|
||||
seq_lens: List[int] = []
|
||||
past_lens: List[int] = []
|
||||
query_lens: List[int] = []
|
||||
multi_model_kwargs_list: List[MultiModalKwargs] = []
|
||||
multi_modal_kwargs_list: List[MultiModalKwargs] = []
|
||||
multi_modal_placeholder_maps: Dict[
|
||||
str,
|
||||
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
|
||||
@ -222,11 +223,15 @@ class OpenVINOModelRunner(ModelRunnerBase):
|
||||
mm_data, placeholder_maps = MultiModalPlaceholderMap \
|
||||
.from_seq_group(seq_group_metadata, positions_range)
|
||||
|
||||
mm_kwargs = self.multi_modal_input_mapper(
|
||||
mm_data,
|
||||
mm_processor_kwargs=seq_group_metadata.
|
||||
mm_processor_kwargs)
|
||||
multi_model_kwargs_list.append(mm_kwargs)
|
||||
if self.mm_registry.has_processor(self.model_config):
|
||||
mm_kwargs = mm_data
|
||||
else:
|
||||
mm_kwargs = self.multi_modal_input_mapper(
|
||||
mm_data,
|
||||
seq_group_metadata.mm_processor_kwargs,
|
||||
)
|
||||
|
||||
multi_modal_kwargs_list.append(mm_kwargs)
|
||||
|
||||
for modality, placeholder_map in placeholder_maps.items():
|
||||
multi_modal_placeholder_maps[modality].extend(
|
||||
@ -275,7 +280,7 @@ class OpenVINOModelRunner(ModelRunnerBase):
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
)
|
||||
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list)
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
|
||||
|
||||
return ModelInput(
|
||||
input_tokens,
|
||||
|
@ -160,7 +160,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
|
||||
input_positions: List[int] = []
|
||||
slot_mapping: List[int] = []
|
||||
seq_lens: List[int] = []
|
||||
multi_model_kwargs_list: List[MultiModalKwargs] = []
|
||||
multi_modal_kwargs_list: List[MultiModalKwargs] = []
|
||||
multi_modal_placeholder_maps: Dict[
|
||||
str,
|
||||
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
|
||||
@ -191,8 +191,16 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
|
||||
mm_data, placeholder_maps = MultiModalPlaceholderMap \
|
||||
.from_seq_group(seq_group_metadata, positions_range)
|
||||
|
||||
mm_kwargs = self.runner.multi_modal_input_mapper(mm_data)
|
||||
multi_model_kwargs_list.append(mm_kwargs)
|
||||
if self.runner.mm_registry.has_processor(
|
||||
self.runner.model_config):
|
||||
mm_kwargs = mm_data
|
||||
else:
|
||||
mm_kwargs = self.runner.multi_modal_input_mapper(
|
||||
mm_data,
|
||||
seq_group_metadata.mm_processor_kwargs,
|
||||
)
|
||||
|
||||
multi_modal_kwargs_list.append(mm_kwargs)
|
||||
|
||||
for modality, placeholder_map in placeholder_maps.items():
|
||||
multi_modal_placeholder_maps[modality].extend(
|
||||
@ -264,7 +272,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
|
||||
block_tables=torch.tensor([], device=self.device, dtype=torch.int),
|
||||
)
|
||||
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list)
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
|
||||
|
||||
return (input_tokens, input_positions, attn_metadata, seq_lens,
|
||||
multi_modal_kwargs)
|
||||
|
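Each model runner hunk above applies the same fallback: when a multi-modal processor is registered for the model, the already-processed data is forwarded as-is, otherwise the legacy per-modality input mapper runs. A condensed sketch of that shared branch, with names mirroring the runner code rather than a self-contained program:

if runner.mm_registry.has_processor(runner.model_config):
    # New path: the data was already turned into MultiModalKwargs
    # by the registered processor.
    mm_kwargs = mm_data
else:
    # Legacy path: run the per-modality input mapper.
    mm_kwargs = runner.multi_modal_input_mapper(
        mm_data,
        seq_group_metadata.mm_processor_kwargs,
    )

multi_modal_kwargs_list.append(mm_kwargs)
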