[Core] Dynamic image size support for VLMs (#5276)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: ywang96 <ywang@roblox.com>
Co-authored-by: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Cyrus Leung 2024-07-03 11:34:00 +08:00 committed by GitHub
parent 482045ee77
commit 9831aec49f
38 changed files with 1453 additions and 664 deletions


@@ -8,7 +8,7 @@ Input Processing
vLLM provides a mechanism for defining input processors for each model so that the inputs are processed
in :class:`~vllm.LLMEngine` before they are passed to model executors.

-Currently, this mechanism is only utilized in **multi-modal models** for preprocessing multi-modal input
+Currently, this mechanism is only utilized in :ref:`multi-modal models <multi_modality>` for preprocessing multi-modal input
data in addition to input prompt, but it can be extended to text-only language models when needed.

Guides


@@ -0,0 +1,124 @@
.. _adding_a_new_multimodal_model:
Adding a New Multimodal Model
=============================
This document provides a high-level guide on integrating a :ref:`multi-modal model <multi_modality>` into vLLM.
.. note::
The complexity of adding a new model depends heavily on the model's architecture.
The process is considerably more straightforward if the model shares a similar architecture with an existing model in vLLM.
However, for models that include new operators (e.g., a new attention mechanism), the process can be a bit more complex.
.. tip::
If you are encountering issues while integrating your model into vLLM, feel free to open an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ repository.
We will be happy to help you out!
1. Set up the base vLLM model
-----------------------------
As usual, follow :ref:`these steps <adding_a_new_model>` to implement the model in vLLM, but note the following:
- You should additionally implement the :class:`~vllm.model_executor.models.interfaces.SupportsVision` interface.
.. code-block:: diff

    + from vllm.model_executor.models.interfaces import SupportsVision

    - class YourModelForImage2Seq(nn.Module):
    + class YourModelForImage2Seq(nn.Module, SupportsVision):
.. note::
The model class does not have to be named :code:`*ForCausalLM`.
Check out `the HuggingFace Transformers documentation <https://huggingface.co/docs/transformers/model_doc/auto#multimodal>`__ for some examples.
- While implementing the :meth:`~torch.nn.Module.forward` method, reserve a keyword parameter
for each input tensor that corresponds to a multi-modal input, as shown in the following example:
.. code-block:: diff

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
    +     pixel_values: torch.Tensor,
      ) -> SamplerOutput:
2. Register input mappers
-------------------------
For each modality type to support, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in :meth:`~torch.nn.Module.forward`.
.. code-block:: diff

      from vllm.model_executor.models.interfaces import SupportsVision
    + from vllm.multimodal import MULTIMODAL_REGISTRY

    + @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
    + @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
      class YourModelForImage2Seq(nn.Module, SupportsVision):
A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.
.. seealso::
:ref:`input_processing_pipeline`
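For illustration, a custom image pixel input mapper might look roughly like the sketch below. The function name, the use of ``CLIPImageProcessor``, and the exact signature are assumptions made for this example only; the authoritative definitions live in :mod:`vllm.multimodal` and :mod:`vllm.inputs.registry`, and the LLaVA models referenced later in this guide are the canonical references.

.. code-block:: python

    import torch.nn as nn
    from PIL import Image
    from transformers import CLIPImageProcessor

    from vllm.inputs.registry import InputContext
    from vllm.model_executor.models.interfaces import SupportsVision
    from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs


    def your_image_pixel_mapper(ctx: InputContext,
                                data: object) -> MultiModalInputs:
        # Hypothetical mapper: convert a PIL image into the ``pixel_values``
        # keyword argument reserved in step 1.
        assert isinstance(data, Image.Image)

        # Re-use the HuggingFace preprocessor of the model being served
        # (assumed to be CLIP-based for this sketch).
        hf_processor = CLIPImageProcessor.from_pretrained(ctx.model_config.model)
        batch = hf_processor(data, return_tensors="pt")

        # Wrap the processed tensors so vLLM can pass them as keyword arguments.
        return MultiModalInputs({"pixel_values": batch["pixel_values"]})


    @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(your_image_pixel_mapper)
    class YourModelForImage2Seq(nn.Module, SupportsVision):
        ...

Returning a dict keyed by ``pixel_values`` lets vLLM forward the processed tensor directly to the keyword parameter reserved in step 1.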
3. (Optional) Register dummy data
---------------------------------
During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models.
In such cases, you can define your own dummy data by registering a factory method via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`.
.. code-block:: diff

      from vllm.inputs import INPUT_REGISTRY
      from vllm.model_executor.models.interfaces import SupportsVision
      from vllm.multimodal import MULTIMODAL_REGISTRY

      @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
      @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
    + @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
      class YourModelForImage2Seq(nn.Module, SupportsVision):
Here are some examples:
- Image inputs (static feature size): `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
- Image inputs (dynamic feature size): `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__
.. seealso::
:ref:`input_processing_pipeline`
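For a rough sketch of what such a factory can look like, assume a model with a fixed image feature size; the token id, feature size and image resolution below are placeholders, and the factory is assumed to receive an ``InputContext`` plus the sequence length and to return a :class:`~vllm.sequence.SequenceData` together with the dummy multi-modal data (see the LLaVA examples above for the authoritative versions):

.. code-block:: python

    from PIL import Image

    from vllm.inputs.registry import InputContext
    from vllm.sequence import SequenceData

    # Placeholder values for illustration only.
    IMAGE_TOKEN_ID = 32000
    IMAGE_FEATURE_SIZE = 576
    IMAGE_WIDTH = IMAGE_HEIGHT = 336


    def your_dummy_data_factory(ctx: InputContext, seq_len: int):
        # Profile memory as if every sequence carried one image:
        # one placeholder token per image feature, then padding.
        token_ids = [IMAGE_TOKEN_ID] * IMAGE_FEATURE_SIZE
        token_ids += [0] * (seq_len - IMAGE_FEATURE_SIZE)
        seq_data = SequenceData(token_ids)

        # A dummy image at the largest resolution the model accepts.
        dummy_image = Image.new("RGB", (IMAGE_WIDTH, IMAGE_HEIGHT), color=0)
        return seq_data, {"image": dummy_image}

The factory would then be passed to ``INPUT_REGISTRY.register_dummy_data(your_dummy_data_factory)`` in place of ``<your_dummy_data_factory>`` above.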
4. (Optional) Register input processor
--------------------------------------
Sometimes, there is a need to process inputs at the :class:`~vllm.LLMEngine` level before they are passed to the model executor.
This is often because, unlike implementations in HuggingFace Transformers, the reshaping and/or expansion of multi-modal embeddings needs to take place outside the model's :meth:`~torch.nn.Module.forward` call.
You can register input processors via :meth:`INPUT_REGISTRY.register_input_processor <vllm.inputs.registry.InputRegistry.register_input_processor>`.
.. code-block:: diff

      from vllm.inputs import INPUT_REGISTRY
      from vllm.model_executor.models.interfaces import SupportsVision
      from vllm.multimodal import MULTIMODAL_REGISTRY

      @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
      @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
      @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
    + @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
      class YourModelForImage2Seq(nn.Module, SupportsVision):
A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
Here are some examples:
- Insert static number of image tokens: `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
- Insert dynamic number of image tokens: `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__
.. seealso::
:ref:`input_processing_pipeline`
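To make the placeholder-token use case concrete, here is a sketch of such a processor. It assumes the processor receives an ``InputContext`` and the :class:`~vllm.inputs.LLMInputs` of a request and returns new ``LLMInputs``; the token id and feature size are illustrative placeholders, and the LLaVA models linked above show the real implementations.

.. code-block:: python

    from typing import List

    from vllm.inputs import LLMInputs
    from vllm.inputs.registry import InputContext

    # Placeholder values for illustration only.
    IMAGE_TOKEN_ID = 32000
    IMAGE_FEATURE_SIZE = 576


    def your_input_processor(ctx: InputContext,
                             llm_inputs: LLMInputs) -> LLMInputs:
        multi_modal_data = llm_inputs.get("multi_modal_data")
        if multi_modal_data is None or "image" not in multi_modal_data:
            # Text-only prompt: nothing to do.
            return llm_inputs

        # Repeat the single placeholder token so that there is one token per
        # image feature, letting vLLM account for the image when it builds
        # the attention mask.
        new_token_ids: List[int] = []
        for token_id in llm_inputs["prompt_token_ids"]:
            if token_id == IMAGE_TOKEN_ID:
                new_token_ids.extend([IMAGE_TOKEN_ID] * IMAGE_FEATURE_SIZE)
            else:
                new_token_ids.append(token_id)

        return LLMInputs(prompt_token_ids=new_token_ids,
                         prompt=llm_inputs.get("prompt"),
                         multi_modal_data=multi_modal_data)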


@@ -1,3 +1,5 @@
+.. _multi_modality:
+
Multi-Modality
==============
@@ -8,12 +10,18 @@ vLLM provides experimental support for multi-modal models through the :mod:`vllm
:class:`vllm.inputs.PromptStrictInputs` accepts an additional attribute ``multi_modal_data``
which allows you to pass in multi-modal input alongside text and token prompts.

-By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model,
-you must decorate the model class with :meth:`InputRegistry.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`,
-as well as :meth:`MULTIMODAL_REGISTRY.register_input_mapper <MultiModalRegistry.register_input_mapper>` for each modality type to support.
+By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model, please follow :ref:`the guide for adding a new multimodal model <adding_a_new_multimodal_model>`.

# TODO: Add more instructions on how to do that once embeddings is in.

+Guides
+++++++
+
+.. toctree::
+   :maxdepth: 1
+
+   adding_multimodal_model
+
Module Contents
+++++++++++++++
@@ -35,6 +43,10 @@ Base Classes
    :members:
    :show-inheritance:

+.. autoclass:: vllm.multimodal.MultiModalInputs
+    :members:
+    :show-inheritance:
+
.. autoclass:: vllm.multimodal.MultiModalPlugin
    :members:
    :show-inheritance:


@@ -23,7 +23,6 @@ The following :ref:`engine arguments <engine_args>` are specific to VLMs:
Currently, the support for vision language models on vLLM has the following limitations:

* Only single image input is supported per text prompt.
-* Dynamic ``image_input_shape`` is not supported: the input image will be resized to the static ``image_input_shape``. This means our LLaVA-NeXT output may not exactly match the huggingface implementation.

We are continuously improving user & developer experience for VLMs. Please `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_ if you have any feedback or feature requests.
@@ -42,12 +41,17 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
    )

.. important::
+    Currently, you have to specify ``image_feature_size`` to support memory profiling.
+    To avoid OOM during runtime, you should set this to the maximum value supported by the model.
+    The calculation of feature size is specific to the model. For more details, please refer to
+    the function :code:`get_<model_name>_image_feature_size` inside the corresponding model file.
+
    We will remove most of the vision-specific arguments in a future release as they can be inferred from the HuggingFace configuration.

To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:

-* ``prompt``: The prompt should have a number of ``<image>`` tokens equal to ``image_feature_size``.
+* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.

.. note::
@@ -57,8 +61,8 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptS
.. code-block:: python

-    prompt = "<image>" * 576 + (
-        "\nUSER: What is the content of this image?\nASSISTANT:")
+    # Refer to the HuggingFace repo for the correct format to use
+    prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"

    # Load the image using PIL.Image
    image = ...
@@ -74,8 +78,6 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptS
A code example can be found in `examples/llava_example.py <https://github.com/vllm-project/vllm/blob/main/examples/llava_example.py>`_.

-.. important::
-
-    We will remove the need to format image tokens in a future release. Afterwards, the input text will follow the same format as that for the original HuggingFace model.
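Putting the pieces of this section together, a minimal offline-inference sketch with the new prompt format might look as follows (engine arguments copied from the LLaVA-1.5 example above; the linked ``examples/llava_example.py`` remains the canonical version):

.. code-block:: python

    from PIL import Image
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
    )

    prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
    image = Image.open("images/stop_sign.jpg")  # any RGB image

    outputs = llm.generate(
        {
            "prompt": prompt,
            # Follows the vllm.multimodal.MultiModalDataDict schema
            "multi_modal_data": {"image": image},
        },
        sampling_params=SamplingParams(temperature=0, max_tokens=64),
    )
    print(outputs[0].outputs[0].text)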
Online OpenAI Vision API Compatible Inference
----------------------------------------------
@@ -103,6 +105,11 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with
        --chat-template template_llava.jinja

.. important::
+    Currently, you have to specify ``image_feature_size`` to support memory profiling.
+    To avoid OOM during runtime, you should set this to the maximum value supported by the model.
+    The calculation of feature size is specific to the model. For more details, please refer to
+    the function :code:`get_<model_name>_image_feature_size` inside the corresponding model file.
+
    We will remove most of the vision-specific arguments in a future release as they can be inferred from the HuggingFace configuration.

To consume the server, you can use the OpenAI client like in the example below:
@@ -121,6 +128,8 @@ To consume the server, you can use the OpenAI client like in the example below:
        messages=[{
            "role": "user",
            "content": [
+               # NOTE: The prompt formatting with the image token `<image>` is not needed
+               # since the prompt will be processed automatically by the API server.
                {"type": "text", "text": "What's in this image?"},
                {
                    "type": "image_url",
@@ -144,5 +153,4 @@ A full code example can be found in `examples/openai_vision_api_client.py <https
    export VLLM_IMAGE_FETCH_TIMEOUT=<timeout>

.. note::
-    The prompt formatting with the image token ``<image>`` is not needed when serving VLMs with the API server since the prompt will be
-    processed automatically by the server.
+    There is no need to format the prompt in the API request since it will be handled by the server.


@@ -17,8 +17,7 @@ def run_llava():
        image_feature_size=576,
    )

-    prompt = "<image>" * 576 + (
-        "\nUSER: What is the content of this image?\nASSISTANT:")
+    prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"

    image = Image.open("images/stop_sign.jpg")


@@ -5,22 +5,17 @@ from PIL import Image
from vllm import LLM, SamplingParams

-# Dynamic image input is currently not supported and therefore
-# a fixed image input shape and its corresponding feature size is required.
-# See https://github.com/vllm-project/vllm/pull/4199 for the complete
-# configuration matrix.
def run_llava_next():
    llm = LLM(
        model="llava-hf/llava-v1.6-mistral-7b-hf",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
-        image_feature_size=1176,
+        # Use the maximum possible value for memory profiling
+        image_feature_size=2928,
    )

-    prompt = "[INST] " + "<image>" * 1176 + (
-        "\nWhat is shown in this image? [/INST]")
+    prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
    url = "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg"
    image = Image.open(BytesIO(requests.get(url).content))

    sampling_params = SamplingParams(temperature=0.8,


@@ -5,6 +5,9 @@ from PIL import Image
from vllm import LLM, SamplingParams

+# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
+# You can use `.buildkite/download-images.sh` to download them

def run_phi3v():
    model_path = "microsoft/Phi-3-vision-128k-instruct"
@@ -18,7 +21,8 @@ def run_phi3v():
        trust_remote_code=True,
        image_token_id=32044,
        image_input_shape="1,3,1008,1344",
-        image_feature_size=1921,
+        # Use the maximum possible value for memory profiling
+        image_feature_size=2653,
        max_num_seqs=5,
    )
@@ -26,8 +30,6 @@ def run_phi3v():
    # single-image prompt
    prompt = "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n"  # noqa: E501
-    prompt = prompt.replace("<|image_1|>", "<|image|>" * 1921 + "<s>")

    sampling_params = SamplingParams(temperature=0, max_tokens=64)

    outputs = llm.generate(


@@ -1,12 +1,13 @@
import contextlib
import gc
import os
+import sys
from collections import UserList
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
-from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple,
-                    TypedDict, TypeVar)
+from typing import (Any, Dict, List, Literal, Optional, Tuple, TypedDict,
+                    TypeVar)

import pytest
import torch
@@ -22,13 +23,10 @@ from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel)
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
+from vllm.multimodal.utils import fetch_image
from vllm.sequence import SampleLogprobs
from vllm.utils import cuda_device_count_stateless, is_cpu

-if TYPE_CHECKING:
-    # it will call torch.cuda.device_count()
-    from vllm.multimodal import MultiModalDataDict
-
logger = init_logger(__name__)

_TEST_DIR = os.path.dirname(__file__)
@@ -47,30 +45,42 @@ def _read_prompts(filename: str) -> List[str]:

@dataclass(frozen=True)
class ImageAsset:
-    name: Literal["stop_sign", "cherry_blossom"]
+    name: Literal["stop_sign", "cherry_blossom", "boardwalk"]

    @cached_property
    def pil_image(self) -> Image.Image:
+        if self.name == "boardwalk":
+            return fetch_image(
+                "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+            )
+
        return Image.open(_IMAGE_DIR / f"{self.name}.jpg")

-    def for_hf(self) -> Image.Image:
-        return self.pil_image
-
-    def for_vllm(self) -> Dict[str, Any]:
-        return {"image": self.pil_image}
-

class _ImageAssetPrompts(TypedDict):
    stop_sign: str
    cherry_blossom: str
+    boardwalk: str


-class _ImageAssets(UserList):
+if sys.version_info < (3, 9):
+    # UserList cannot be subscripted
+    class _ImageAssetsBase(UserList):
+        pass
+else:
+    class _ImageAssetsBase(UserList[ImageAsset]):
+        pass
+
+
+class _ImageAssets(_ImageAssetsBase):

    def __init__(self) -> None:
-        super().__init__(
-            [ImageAsset("stop_sign"),
-             ImageAsset("cherry_blossom")])
+        super().__init__([
+            ImageAsset("stop_sign"),
+            ImageAsset("cherry_blossom"),
+            ImageAsset("boardwalk")
+        ])

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """
@@ -79,7 +89,10 @@ class _ImageAssets(UserList):
        The order of the returned prompts matches the order of the
        assets when iterating through this object.
        """
-        return [prompts["stop_sign"], prompts["cherry_blossom"]]
+        return [
+            prompts["stop_sign"], prompts["cherry_blossom"],
+            prompts["boardwalk"]
+        ]


IMAGE_ASSETS = _ImageAssets()
@@ -220,7 +233,7 @@ class HfRunner:
        self,
        prompts: List[str],
        images: Optional[List[Image.Image]] = None,
-        **kwargs,
+        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        if images:
            assert len(prompts) == len(images)
@@ -255,7 +268,7 @@ class HfRunner:
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
-        **kwargs,
+        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        outputs = self.generate(prompts,
                                do_sample=False,
@@ -291,19 +304,30 @@ class HfRunner:
        self,
        prompts: List[str],
        max_tokens: int,
+        images: Optional[List[Image.Image]] = None,
+        **kwargs: Any,
    ) -> List[List[torch.Tensor]]:
-        all_logprobs = []
-        for prompt in prompts:
-            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+        all_logprobs: List[List[torch.Tensor]] = []
+        for i, prompt in enumerate(prompts):
+            processor_kwargs: Dict[str, Any] = {
+                "text": prompt,
+                "return_tensors": "pt",
+            }
+            if images is not None and images[i] is not None:
+                processor_kwargs["images"] = images[i]
+
+            inputs = self.processor(**processor_kwargs)
+
            output = self.model.generate(
-                self.wrap_device(input_ids),
+                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
+                **kwargs,
            )
-            seq_logprobs = []
+            seq_logprobs: List[torch.Tensor] = []
            for hidden_states in output.hidden_states:
                last_hidden_states = hidden_states[-1][0]
                logits = torch.matmul(
@@ -323,20 +347,32 @@ class HfRunner:
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
+        images: Optional[List[Image.Image]] = None,
+        **kwargs: Any,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

-        for prompt in prompts:
-            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+        for i, prompt in enumerate(prompts):
+            processor_kwargs: Dict[str, Any] = {
+                "text": prompt,
+                "return_tensors": "pt",
+            }
+            if images is not None and images[i] is not None:
+                processor_kwargs["images"] = images[i]
+
+            inputs = self.processor(**processor_kwargs)
+            input_ids = inputs.input_ids
+
            output = self.model.generate(
-                self.wrap_device(input_ids),
+                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
+                **kwargs,
            )

            seq_logprobs: List[torch.Tensor] = []
@@ -431,7 +467,7 @@ class VllmRunner:
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
-        images: Optional[List["MultiModalDataDict"]] = None,
+        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        if images is not None:
            assert len(prompts) == len(images)
@@ -439,7 +475,7 @@ class VllmRunner:
        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            for i, image in enumerate(images):
-                inputs[i]["multi_modal_data"] = image
+                inputs[i]["multi_modal_data"] = {"image": image}

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)
@@ -462,10 +498,19 @@ class VllmRunner:
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
+        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        assert sampling_params.logprobs is not None

+        if images is not None:
+            assert len(prompts) == len(images)
+
+        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
+        if images is not None:
+            for i, image in enumerate(images):
+                inputs[i]["multi_modal_data"] = {"image": image}
+
-        req_outputs = self.model.generate(prompts,
+        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)
        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
        for req_output in req_outputs:
@@ -480,7 +525,7 @@ class VllmRunner:
        self,
        prompts: List[str],
        max_tokens: int,
-        images: Optional[List["MultiModalDataDict"]] = None,
+        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[int], str]]:
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images)
@@ -492,11 +537,14 @@ class VllmRunner:
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
+        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs)
-        outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params)
+        outputs = self.generate_w_logprobs(prompts,
+                                           greedy_logprobs_params,
+                                           images=images)

        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]


@@ -30,9 +30,10 @@ else:
@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets,
-                tensor_parallel_size: int, dtype: str,
-                max_tokens: int) -> None:
+                tensor_parallel_size: int, dtype: str, max_tokens: int,
+                num_logprobs: int) -> None:
    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(
            f"Need at least {tensor_parallel_size} GPUs to run the test.")
@@ -44,8 +45,10 @@ def test_models(hf_runner, vllm_runner, image_assets,
        vllm_runner,
        image_assets,
        model_and_config=model_and_vl_config[0],
+        size_factors=[1.0],
        dtype=dtype,
        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
        tensor_parallel_size=tensor_parallel_size,
        distributed_executor_backend=distributed_executor_backend,
    )


@@ -4,18 +4,21 @@ import pytest
from transformers import AutoTokenizer

from vllm.config import VisionLanguageConfig
+from vllm.multimodal.utils import rescale_image_size
+from vllm.sequence import SampleLogprobs

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
-from .utils import check_outputs_equal
+from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

-# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
-    "<image>\nUSER: What's the content of the image?\nASSISTANT:",
+    "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
    "cherry_blossom":
-    "<image>\nUSER: What is the season?\nASSISTANT:",
+    "USER: <image>\nWhat is the season?\nASSISTANT:",
+    "boardwalk":
+    "USER: <image>\nWhat's in this image?\nASSISTANT:",
})
@@ -37,27 +40,34 @@ model_and_vl_config = [
]


-def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
+def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
+                                         Optional[SampleLogprobs]],
                      vlm_config: VisionLanguageConfig, model_id: str):
    """Sanitize vllm output to be comparable with hf output.
    The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
    x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
    It also reduces `output_str` from "<image><image>bla" to "bla".
    """
-    output_ids, output_str = vllm_output
+    output_ids, output_str, out_logprobs = vllm_output
    image_token_id = vlm_config.image_token_id

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    image_token_str = tokenizer.decode(image_token_id)
+    eos_token_id = tokenizer.eos_token_id

    hf_output_ids = [
        token_id for idx, token_id in enumerate(output_ids)
        if token_id != image_token_id or output_ids[idx - 1] != image_token_id
    ]

    hf_output_str = output_str \
        .replace(image_token_str * vlm_config.image_feature_size, "")
+    assert hf_output_str[0] == " "
+    hf_output_str = hf_output_str[1:]
+    if hf_output_ids[-1] == eos_token_id:
+        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

-    return hf_output_ids, hf_output_str
+    return hf_output_ids, hf_output_str, out_logprobs


def run_test(
@@ -66,8 +76,10 @@ def run_test(
    image_assets: _ImageAssets,
    model_and_config: Tuple[str, VisionLanguageConfig],
    *,
+    size_factors: List[float],
    dtype: str,
    max_tokens: int,
+    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
@@ -81,42 +93,49 @@ def run_test(
    The text output is sanitized to be able to compare with hf.
    """
    model_id, vlm_config = model_and_config
-    hf_images = [asset.for_hf() for asset in image_assets]
+    images = [asset.pil_image for asset in image_assets]
+
+    inputs_per_image = [(
+        [prompt for _ in size_factors],
+        [rescale_image_size(image, factor) for factor in size_factors],
+    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
+
+    # max_model_len should be greater than image_feature_size
    with vllm_runner(model_id,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True,
                     **vlm_config.as_cli_args_dict()) as vllm_model:
-        # NOTE: `asset.for_vllm` will call `torch.cuda.device_count()`
-        # we must put it inside the vllm_runner context manager
-        # i.e. after creating vLLM instance.
-        vllm_images = [asset.for_vllm() for asset in image_assets]
-
-        vllm_image_prompts = [
-            p.replace("<image>", "<image>" * vlm_config.image_feature_size)
-            for p in HF_IMAGE_PROMPTS
-        ]
-
-        vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
-                                                  max_tokens,
-                                                  images=vllm_images)
+        vllm_outputs_per_image = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                images=images)
+            for prompts, images in inputs_per_image
+        ]

    with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
-        hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
-                                              max_tokens,
-                                              images=hf_images)
+        hf_outputs_per_image = [
+            hf_model.generate_greedy_logprobs_limit(prompts,
+                                                    max_tokens,
+                                                    num_logprobs=num_logprobs,
+                                                    images=images)
+            for prompts, images in inputs_per_image
+        ]

-    check_outputs_equal(
-        hf_outputs,
-        [
+    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
+                                        vllm_outputs_per_image):
+        # TODO: Check whether using original CLIPVisionModel can improve
+        # consistency against HF
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=[
                vllm_to_hf_output(vllm_output, vlm_config, model_id)
                for vllm_output in vllm_outputs
            ],
@@ -126,16 +145,33 @@ def run_test(
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
+@pytest.mark.parametrize(
+    "size_factors",
+    [
+        # No image
+        [],
+        # Single-scale
+        [1.0],
+        # Single-scale, batched
+        [1.0, 1.0, 1.0],
+        # Multi-scale
+        [0.25, 0.5, 1.0],
+    ],
+)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
-                dtype: str, max_tokens: int) -> None:
+                size_factors, dtype: str, max_tokens: int,
+                num_logprobs: int) -> None:
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model_and_config,
+        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )


@@ -1,12 +1,15 @@
-from typing import List, Tuple
+import re
+from typing import List, Optional, Tuple

import pytest
from transformers import AutoTokenizer

from vllm.config import VisionLanguageConfig
+from vllm.multimodal.utils import rescale_image_size
+from vllm.sequence import SampleLogprobs

from ..conftest import IMAGE_ASSETS
-from .utils import check_outputs_equal
+from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm
@@ -15,21 +18,20 @@ _PREFACE = (
    "The assistant gives helpful, detailed, and polite answers to the human's "
    "questions.")

-# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
-    f"{_PREFACE} <image>\nUSER: What's the content of the image?\nASSISTANT:",
+    f"{_PREFACE} USER: <image>\nWhat's the content of the image? ASSISTANT:",
    "cherry_blossom":
-    f"{_PREFACE} <image>\nUSER: What is the season?\nASSISTANT:",
+    f"{_PREFACE} USER: <image>\nWhat is the season? ASSISTANT:",
+    "boardwalk":
+    f"{_PREFACE} USER: <image>\nWhat's in this image? ASSISTANT:",
})


def iter_llava_next_configs(model_name: str):
+    # Need to use the max possible feature size for profile_run
    image_hw_to_feature_size = {
-        (336, 336): 1176,
-        (672, 672): 2928,
-        (1344, 336): 1944,
-        (336, 1344): 1890,
+        (336, 336): 2928,
    }

    for (h, w), f in image_hw_to_feature_size.items():
@@ -47,37 +49,55 @@ model_and_vl_config = [
]


-def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
+def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
+                                         Optional[SampleLogprobs]],
                      vlm_config: VisionLanguageConfig, model_id: str):
    """Sanitize vllm output to be comparable with hf output.
    The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
    x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
    It also reduces `output_str` from "<image><image>bla" to "bla".
    """
-    output_ids, output_str = vllm_output
+    output_ids, output_str, out_logprobs = vllm_output
    image_token_id = vlm_config.image_token_id

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    image_token_str = tokenizer.decode(image_token_id)
+    eos_token_id = tokenizer.eos_token_id

    hf_output_ids = [
        token_id for idx, token_id in enumerate(output_ids)
        if token_id != image_token_id or output_ids[idx - 1] != image_token_id
    ]

-    hf_output_str = output_str \
-        .replace(image_token_str * vlm_config.image_feature_size, " ")
-
-    return hf_output_ids, hf_output_str
+    hf_output_str = re.sub(fr"({image_token_str})+", "", output_str)
+    assert hf_output_str[0] == " "
+    hf_output_str = hf_output_str[1:]
+    if hf_output_ids[-1] == eos_token_id:
+        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
+
+    return hf_output_ids, hf_output_str, out_logprobs


+@pytest.mark.xfail(
+    reason="Inconsistent image processor being used due to lack "
+    "of support for dynamic image token replacement")
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
+@pytest.mark.parametrize(
+    "size_factors",
+    [
+        # No image
+        [],
+        # Single-scale
+        [1.0],
+        # Single-scale, batched
+        [1.0, 1.0, 1.0],
+        # Multi-scale
+        [0.25, 0.5, 1.0],
+    ],
+)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
-                dtype: str, max_tokens: int) -> None:
+                size_factors, dtype: str, max_tokens: int,
+                num_logprobs: int) -> None:
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test is under tests/images.
@@ -88,34 +108,43 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
    The text output is sanitized to be able to compare with hf.
    """
    model_id, vlm_config = model_and_config
-    hf_images = [asset.for_hf() for asset in image_assets]
-    vllm_images = [asset.for_vllm() for asset in image_assets]
-
-    with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
-        hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
-                                              max_tokens,
-                                              images=hf_images)
-
-    vllm_image_prompts = [
-        p.replace("<image>", "<image>" * vlm_config.image_feature_size)
-        for p in HF_IMAGE_PROMPTS
-    ]
-
-    with vllm_runner(
-        model_id,
-        dtype=dtype,
-        # should be greater than image_feature_size
-        max_model_len=4096,
-        enforce_eager=True,
-        **vlm_config.as_cli_args_dict(),
-    ) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
-                                                  max_tokens,
-                                                  images=vllm_images)
-
-    check_outputs_equal(
-        hf_outputs,
-        [
+    images = [asset.pil_image for asset in image_assets]
+
+    inputs_per_image = [(
+        [prompt for _ in size_factors],
+        [rescale_image_size(image, factor) for factor in size_factors],
+    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+
+    # max_model_len should be greater than image_feature_size
+    with vllm_runner(model_id,
+                     dtype=dtype,
+                     max_model_len=4096,
+                     enforce_eager=True,
+                     **vlm_config.as_cli_args_dict()) as vllm_model:
+        vllm_outputs_per_image = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                images=images)
+            for prompts, images in inputs_per_image
+        ]
+
+    with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
+        hf_outputs_per_image = [
+            hf_model.generate_greedy_logprobs_limit(prompts,
+                                                    max_tokens,
+                                                    num_logprobs=num_logprobs,
+                                                    images=images)
+            for prompts, images in inputs_per_image
+        ]
+
+    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
+                                        vllm_outputs_per_image):
+        # TODO: Check whether using original CLIPVisionModel can improve
+        # consistency against HF
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=[
                vllm_to_hf_output(vllm_output, vlm_config, model_id)
                for vllm_output in vllm_outputs
            ],


@@ -1,29 +1,33 @@
+import re
from typing import List, Optional, Tuple, Type

import pytest
from transformers import AutoTokenizer

from vllm.config import VisionLanguageConfig
+from vllm.multimodal.utils import rescale_image_size
+from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
-from .utils import check_outputs_equal
+from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

-# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n",  # noqa: E501
    "cherry_blossom":
-    "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",  # noqa: E501
+    "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
+    "boardwalk":
+    "<|user|>\n<|image_1|>\nWhat's in this image?<|end|>\n<|assistant|>\n",
})


def iter_phi3v_configs(model_name: str):
+    # Need to use the max possible feature size for profile_run
    image_hw_to_feature_size = {
-        (1008, 1344): 1921,
-        (2016, 2688): 1933,
+        (1008, 1344): 2653,
    }

    for (h, w), f in image_hw_to_feature_size.items():
@@ -39,29 +43,29 @@ model_and_vl_config = [
]


-def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
+def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
+                                         Optional[SampleLogprobs]],
                      vlm_config: VisionLanguageConfig, model_id: str):
    """Sanitize vllm output to be comparable with hf output.
    The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
    x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
    It also reduces `output_str` from "<image><image>bla" to "bla".
    """
-    output_ids, output_str = vllm_output
-    image_token_id = vlm_config.image_token_id
-
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    image_token_str = tokenizer.decode(image_token_id)
-
-    hf_output_ids = [
-        token_id if token_id != image_token_id else 0
-        for idx, token_id in enumerate(output_ids)
-    ]
-
-    hf_output_str = output_str \
-        .replace(image_token_str * vlm_config.image_feature_size, "") \
-        .replace("<s>", " ").replace("<|user|>", "") \
+    output_ids, output_str, out_logprobs = vllm_output
+
+    output_str_without_image = re.sub(r"(<\|image_\d+\|>)+", "", output_str)
+    assert output_str_without_image[0] == " "
+    output_str_without_image = output_str_without_image[1:]
+
+    hf_output_str = output_str_without_image.replace("<|user|>", "") \
        .replace("<|end|>\n<|assistant|>", " ")

-    return hf_output_ids, hf_output_str
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    hf_output_ids = tokenizer.encode(output_str_without_image)
+    assert hf_output_ids[0] == 1
+    hf_output_ids = hf_output_ids[1:]
+
+    return hf_output_ids, hf_output_str, out_logprobs


target_dtype = "half"
@@ -75,8 +79,10 @@ def run_test(
    image_assets: _ImageAssets,
    model_and_config: Tuple[str, VisionLanguageConfig],
    *,
+    size_factors: List[float],
    dtype: str,
    max_tokens: int,
+    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
@@ -90,49 +96,53 @@ def run_test(
    The text output is sanitized to be able to compare with hf.
    """
    model_id, vlm_config = model_and_config
-    hf_images = [asset.for_hf() for asset in image_assets]
+    images = [asset.pil_image for asset in image_assets]
+
+    inputs_per_image = [(
+        [prompt for _ in size_factors],
+        [rescale_image_size(image, factor) for factor in size_factors],
+    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
+
+    # max_model_len should be greater than image_feature_size
    with vllm_runner(model_id,
-                     max_model_len=2048,
+                     max_model_len=4096,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
-                     enforce_eager=True,
                     distributed_executor_backend=distributed_executor_backend,
+                     enforce_eager=True,
                     **vlm_config.as_cli_args_dict()) as vllm_model:
-        # NOTE: `asset.for_vllm` will call `torch.cuda.device_count()`
-        # we must put it inside the vllm_runner context manager
-        # i.e. after creating vLLM instance.
-        vllm_images = [asset.for_vllm() for asset in image_assets]
-
-        vllm_image_prompts = [
-            p.replace("<|image_1|>",
-                      "<|image|>" * vlm_config.image_feature_size + "<s>")
-            for p in HF_IMAGE_PROMPTS
-        ]
-
-        vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
-                                                  max_tokens,
-                                                  images=vllm_images)
+        vllm_outputs_per_image = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                images=vllm_images)
+            for prompts, vllm_images in inputs_per_image
+        ]

    # use eager mode for hf runner, since phi3_v didn't work with flash_attn
    hf_model_kwargs = {"_attn_implementation": "eager"}
    with hf_runner(model_id, dtype=dtype,
                   model_kwargs=hf_model_kwargs) as hf_model:
-        hf_outputs = hf_model.generate_greedy(
-            HF_IMAGE_PROMPTS,
-            max_tokens,
-            images=hf_images,
-            eos_token_id=hf_model.processor.tokenizer.eos_token_id)
+        eos_token_id = hf_model.processor.tokenizer.eos_token_id
+        hf_outputs_per_image = [
+            hf_model.generate_greedy_logprobs_limit(prompts,
+                                                    max_tokens,
+                                                    num_logprobs=num_logprobs,
+                                                    images=hf_images,
+                                                    eos_token_id=eos_token_id)
+            for prompts, hf_images in inputs_per_image
+        ]

-    check_outputs_equal(
-        hf_outputs,
-        [
+    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
+                                        vllm_outputs_per_image):
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=[
                vllm_to_hf_output(vllm_output, vlm_config, model_id)
                for vllm_output in vllm_outputs
            ],
@@ -141,22 +151,36 @@ def run_test(
    )


-# Since we use _attn_implementation="eager" for hf_runner, here is
-# numeric difference for longer context and test can't pass
+# Since we use _attn_implementation="eager" for hf_runner, there is more
+# significant numerical difference. The basic `logprobs=5` fails to pass.
+@pytest.mark.xfail(
+    reason="Inconsistent image processor being used due to lack "
+    "of support for dynamic image token replacement")
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
+@pytest.mark.parametrize(
+    "size_factors",
+    [
+        # No image
+        [],
+        # Single-scale
+        [1.0],
+        # Single-scale, batched
+        [1.0, 1.0, 1.0],
+        # Multi-scale
+        [0.25, 0.5, 1.0],
+    ],
+)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [10])
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
-                dtype: str, max_tokens: int) -> None:
+                size_factors, dtype: str, max_tokens: int,
+                num_logprobs: int) -> None:
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model_and_config,
+        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )


@@ -1,11 +1,18 @@
-from typing import Dict, List, Tuple
+import warnings
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+
+from vllm.sequence import SampleLogprobs

TokensText = Tuple[List[int], str]


-def check_outputs_equal(outputs_0_lst: List[TokensText],
-                        outputs_1_lst: List[TokensText], name_0: str,
-                        name_1: str):
+def check_outputs_equal(
+    *,
+    outputs_0_lst: Sequence[TokensText],
+    outputs_1_lst: Sequence[TokensText],
+    name_0: str,
+    name_1: str,
+):
    """
    Compare the two sequences generated by different models,
    which should be equal.
@@ -18,20 +25,28 @@ def check_outputs_equal(outputs_0_lst: List[TokensText],
        output_ids_0, output_str_0 = outputs_0
        output_ids_1, output_str_1 = outputs_1

-        assert output_str_0 == output_str_1, (f"Test{prompt_idx}:"
-                                              f"\n{name_0}:\t{output_str_0!r}"
-                                              f"\n{name_1}:\t{output_str_1!r}")
-        assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:"
-                                              f"\n{name_0}:\t{output_str_0!r}"
-                                              f"\n{name_1}:\t{output_str_1!r}")
+        # The text and token outputs should exactly match
+        fail_msg = (f"Test{prompt_idx}:"
+                    f"\n{name_0}:\t{output_str_0!r}"
+                    f"\n{name_1}:\t{output_str_1!r}")
+
+        assert output_str_0 == output_str_1, fail_msg
+        assert output_ids_0 == output_ids_1, fail_msg


-TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]
+TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int,
+                                                                    float]],
+                                                          SampleLogprobs]]]


-def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
-                         outputs_1_lst: List[TokensTextLogprobs], name_0: str,
-                         name_1: str):
+def check_logprobs_close(
+    *,
+    outputs_0_lst: Sequence[TokensTextLogprobs],
+    outputs_1_lst: Sequence[TokensTextLogprobs],
+    name_0: str,
+    name_1: str,
+    warn_on_mismatch: bool = True,
+):
    """
    Compare the logprobs of two sequences generated by different models,
    which should be similar but not necessarily equal.
@@ -45,21 +60,52 @@ def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
        output_ids_0, output_str_0, logprobs_0 = outputs_0
        output_ids_1, output_str_1, logprobs_1 = outputs_1

+        if logprobs_0 is None:
+            logprobs_0 = [None] * len(output_ids_0)
+        if logprobs_1 is None:
+            logprobs_1 = [None] * len(output_ids_1)
+
        # Loop through generated tokens.
        for idx, (output_id_0,
                  output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):

            # If generated tokens don't match, then
            if output_id_0 != output_id_1:
+                logprobs_elem_0 = logprobs_0[idx]
+                logprobs_elem_1 = logprobs_1[idx]
+
                # Each predicted token must be in top N logprobs of the other
-                assert output_id_0 in logprobs_1[idx], (
-                    f"Test{prompt_idx}:"
-                    f"\n{name_0}:\t{output_str_0!r}"
-                    f"\n{name_1}:\t{output_str_1!r}")
-                assert output_id_1 in logprobs_0[idx], (
-                    f"Test{prompt_idx}:"
-                    f"\n{name_0}:\t{output_str_0!r}"
-                    f"\n{name_1}:\t{output_str_1!r}")
+                fail_msg = (
+                    f"Test{prompt_idx}:"
+                    f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}"
+                    f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}")
+
+                assert logprobs_elem_0 is not None, fail_msg
+                assert logprobs_elem_1 is not None, fail_msg
+                assert output_id_0 in logprobs_elem_1, fail_msg
+                assert output_id_1 in logprobs_elem_0, fail_msg
+
+                if warn_on_mismatch:
+                    with warnings.catch_warnings():
+                        # This ensures that repeated warnings are shown
+                        # in the output, not just the first occurrence
+                        warnings.simplefilter("always")
+
+                        warnings.warn(fail_msg, stacklevel=2)

                # Break out since sequences will now diverge.
                break
+        else:
+            if output_str_0 != output_str_1 and warn_on_mismatch:
+                # The token outputs exactly match,
+                # so the text outputs should exactly match as well
+                fail_msg = (f"Test{prompt_idx}:"
+                            f"\n{name_0}:\t{output_str_0!r}"
+                            f"\n{name_1}:\t{output_str_1!r}")
+
+                with warnings.catch_warnings():
+                    # This ensures that repeated warnings are shown
+                    # in the output, not just the first occurrence
+                    warnings.simplefilter("always")
+
+                    warnings.warn(fail_msg, stacklevel=2)


@@ -4,12 +4,12 @@ from transformers import CLIPImageProcessor, LlavaNextImageProcessor
from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.utils import rescale_image_size

-from ..conftest import _STR_DTYPE_TO_TORCH_DTYPE


@pytest.mark.parametrize("dtype", ["half", "float"])
-def test_clip_image_processor(image_assets, dtype):
+@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
+def test_clip_image_processor(image_assets, dtype, size_factor):
    MODEL_NAME = "llava-hf/llava-1.5-7b-hf"

    hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
@ -26,13 +26,15 @@ def test_clip_image_processor(image_assets, dtype):
) )
for asset in image_assets: for asset in image_assets:
image = rescale_image_size(asset.pil_image, size_factor)
hf_result = hf_processor.preprocess( hf_result = hf_processor.preprocess(
asset.pil_image, image,
return_tensors="pt", return_tensors="pt",
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype]) )
vllm_result = MULTIMODAL_REGISTRY.map_input( vllm_result = MULTIMODAL_REGISTRY.map_input(
model_config, model_config,
{"image": asset.pil_image}, {"image": image},
) )
assert hf_result.keys() == vllm_result.keys() assert hf_result.keys() == vllm_result.keys()
@ -44,12 +46,10 @@ def test_clip_image_processor(image_assets, dtype):
assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}" assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"
@pytest.mark.xfail(
reason="Inconsistent image processor being used due to lack "
"of support for dynamic image token replacement")
@pytest.mark.parametrize("dtype", ["half", "float"]) @pytest.mark.parametrize("dtype", ["half", "float"])
def test_llava_next_image_processor(image_assets, dtype): @pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
MODEL_NAME = "llava-hf/llava-v1.6-34b-hf" def test_llava_next_image_processor(image_assets, dtype, size_factor):
MODEL_NAME = "llava-hf/llava-v1.6-vicuna-7b-hf"
hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME) hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME)
assert isinstance(hf_processor, LlavaNextImageProcessor) assert isinstance(hf_processor, LlavaNextImageProcessor)
@ -65,13 +65,15 @@ def test_llava_next_image_processor(image_assets, dtype):
) )
for asset in image_assets: for asset in image_assets:
image = rescale_image_size(asset.pil_image, size_factor)
hf_result = hf_processor.preprocess( hf_result = hf_processor.preprocess(
asset.pil_image, image,
return_tensors="pt", return_tensors="pt",
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype]) )
vllm_result = MULTIMODAL_REGISTRY.map_input( vllm_result = MULTIMODAL_REGISTRY.map_input(
model_config, model_config,
{"image": asset.pil_image}, {"image": image},
) )
assert hf_result.keys() == vllm_result.keys() assert hf_result.keys() == vllm_result.keys()
@ -81,36 +83,3 @@ def test_llava_next_image_processor(image_assets, dtype):
assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}" assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}"
assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}" assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"
@pytest.mark.xfail(
reason="Example image pixels were not processed using HuggingFace")
@pytest.mark.parametrize("dtype", ["float"])
def test_image_pixel_types(image_assets, dtype):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
model_config = ModelConfig(
model=MODEL_NAME,
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype=dtype,
revision=None,
)
for asset in image_assets:
image_result = MULTIMODAL_REGISTRY.map_input(
model_config,
{"image": asset.pil_image},
)
tensor_result = MULTIMODAL_REGISTRY.map_input(
model_config,
{"image": asset.pil_image},
)
assert image_result.keys() == tensor_result.keys()
for key, image_arr in image_result.items():
tensor_arr: np.ndarray = tensor_result[key].numpy()
assert image_arr.shape == tensor_arr.shape, f"Failed for key={key}"
assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}"

@@ -5,10 +5,9 @@ from typing import Dict, Tuple
import numpy as np
import pytest
import pytest_asyncio
from PIL import Image
from vllm.multimodal.utils import ImageFetchAiohttp
from vllm.multimodal.utils import ImageFetchAiohttp, fetch_image
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
@@ -19,12 +18,9 @@ TEST_IMAGE_URLS = [
]
@pytest_asyncio.fixture(scope="session")
async def url_images() -> Dict[str, Image.Image]:
return {
image_url: await ImageFetchAiohttp.fetch_image(image_url)
for image_url in TEST_IMAGE_URLS
}
@pytest.fixture(scope="module")
def url_images() -> Dict[str, Image.Image]:
return {image_url: fetch_image(image_url) for image_url in TEST_IMAGE_URLS}
def get_supported_suffixes() -> Tuple[str, ...]:
@@ -41,7 +37,15 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
return (np.asarray(a) == np.asarray(b.convert(a.mode))).all()
@pytest.mark.asyncio
@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_http(image_url: str):
image_sync = fetch_image(image_url)
image_async = await ImageFetchAiohttp.fetch_image(image_url)
assert _image_equals(image_sync, image_async)
@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@pytest.mark.parametrize("suffix", get_supported_suffixes()) @pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: Dict[str, Image.Image], async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
@ -68,8 +72,11 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
base64_image = base64.b64encode(f.read()).decode("utf-8") base64_image = base64.b64encode(f.read()).decode("utf-8")
data_url = f"data:{mime_type};base64,{base64_image}" data_url = f"data:{mime_type};base64,{base64_image}"
data_image = await ImageFetchAiohttp.fetch_image(data_url) data_image_sync = fetch_image(data_url)
if _image_equals(url_image, Image.open(f)): if _image_equals(url_image, Image.open(f)):
assert _image_equals(url_image, data_image) assert _image_equals(url_image, data_image_sync)
else: else:
pass # Lossy format; only check that image can be opened pass # Lossy format; only check that image can be opened
data_image_async = await ImageFetchAiohttp.fetch_image(data_url)
assert _image_equals(data_image_sync, data_image_async)

@@ -5,7 +5,7 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple,
Union)
import torch
from transformers import PretrainedConfig, PreTrainedTokenizerBase
from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.logger import init_logger
@@ -1303,16 +1303,6 @@ class VisionLanguageConfig:
image_input_shape: tuple
image_feature_size: int
#TODO(ywang96): make this a cached property once we refactor the
# VisionLanguageConfig class.
def get_image_token_text(
self, tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]:
"""Get the image token placeholder text to be inserted into the
text prompt and the string representation of the image token id.
"""
image_token_str = tokenizer.decode(self.image_token_id)
return image_token_str * self.image_feature_size, image_token_str
def as_cli_args_dict(self) -> Dict[str, Any]:
"""Flatten vision language config to pure args.

@@ -1,6 +1,7 @@
import codecs
import time
from dataclasses import dataclass, field
from functools import cached_property
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable,
List, Optional)
from typing import Sequence as GenericSequence
@@ -10,7 +11,7 @@ from fastapi import Request
from openai.types.chat import (ChatCompletionContentPartImageParam,
ChatCompletionContentPartTextParam)
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
ChatCompletionContentPartParam, ChatCompletionLogProb,
@@ -27,8 +28,7 @@ from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (async_get_and_parse_image,
get_full_image_text_prompt)
from vllm.multimodal.utils import async_get_and_parse_image
from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
@@ -97,6 +97,36 @@ class OpenAIServingChat(OpenAIServing):
logger.warning(
"No chat template provided. Chat API will not work.")
@cached_property
def image_token_str(self) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type = self.model_config.hf_config.model_type
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return "<|image_1|>"
if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv",
"paligemma"):
# These models do not use image tokens in the prompt
return None
# The default behaviour assumes that the image token is
# available to the tokenizer.
# (Suitable for LLaVA, Idefics2, DeepSeek-VL)
vlm_config = self.model_config.multimodal_config
if vlm_config is None:
raise ValueError(
"'image_url' input is not supported as the loaded "
"model is not multimodal.")
image_token_id = vlm_config.image_token_id
if vlm_config.image_token_id is None:
raise ValueError(
"'image_url' input is not supported as the loaded "
"model does not specify an image token.")
return self.tokenizer.decode(image_token_id)
def _parse_chat_message_content_parts(
self,
role: str,
@@ -105,21 +135,26 @@ class OpenAIServingChat(OpenAIServing):
texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
vlm_config: Optional[VisionLanguageConfig] = getattr(
self.engine.engine, "vision_language_config", None)
model_config = getattr(self.engine.engine, "model_config", None)
for part in parts:
part_type = part["type"]
if part_type == "text":
text = cast(ChatCompletionContentPartTextParam, part)["text"]
texts.append(text)
elif part_type == "image_url":
if vlm_config is None:
raise ValueError(
"'image_url' input is not supported as the loaded "
"model is not multimodal.")
if len(mm_futures) > 0:
raise NotImplementedError(
"Multiple 'image_url' input is currently not supported."
)
assert self.tokenizer is not None
image_token_str = self.image_token_str
if image_token_str is not None:
if any(image_token_str in text for text in texts):
logger.warning(
"Detected image token string in the text prompt. "
"Skipping prompt formatting.")
else:
texts.append(image_token_str)
image_url = cast(ChatCompletionContentPartImageParam,
part)["image_url"]
@@ -128,42 +163,12 @@ class OpenAIServingChat(OpenAIServing):
"'image_url.detail' is currently not supported and "
"will be ignored.")
mm_future = async_get_and_parse_image(image_url["url"])
mm_futures.append(mm_future)
image_future = async_get_and_parse_image(image_url["url"])
mm_futures.append(image_future)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
if vlm_config is not None and len(mm_futures):
assert len(
mm_futures
) == 1, "Multiple 'image_url' input is currently not supported."
(image_token_prompt,
image_token_str) = vlm_config.get_image_token_text(self.tokenizer)
# NOTE: If image token string (e.g, <image>) is already present
# in the text prompt, we assume it follows the same format required
# by the engine.
if image_token_str in text_prompt:
logger.warning(
"Detected image token string in the text prompt. "
"Skipping prompt formatting.")
messages = [
ConversationMessage(role=role, content=text_prompt)
]
else:
full_prompt = get_full_image_text_prompt(
image_prompt=image_token_prompt,
text_prompt=text_prompt,
config=model_config)
messages = [
ConversationMessage(role=role, content=full_prompt)
]
else:
messages = [ConversationMessage(role=role, content=text_prompt)]
return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
@@ -267,7 +272,7 @@ class OpenAIServingChat(OpenAIServing):
"prompt": prompt_text,
"prompt_token_ids": prompt_ids,
}
if mm_data is not None:
if mm_data:
inputs["multi_modal_data"] = mm_data
is_tracing_enabled = await self.engine.is_tracing_enabled()

@@ -36,6 +36,7 @@ class OpenAIServing:
super().__init__()
self.engine = engine
self.model_config = model_config
self.max_model_len = model_config.max_model_len
# A separate tokenizer to map token IDs to strings.

@@ -140,7 +140,8 @@ class InputRegistry:
The model is identified by ``model_config``.
TODO: Add guide [ref: PR #5276]
See also:
:ref:`adding_a_new_multimodal_model`
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture

@@ -8,10 +8,14 @@ from PIL import Image
from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPAttention
from vllm.config import ModelConfig
from vllm.inputs import LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.sequence import SequenceData
@@ -64,6 +68,39 @@ def dummy_image_for_clip(
return {"image": image}
def input_processor_for_clip(
model_config: ModelConfig,
hf_config: CLIPVisionConfig,
llm_inputs: LLMInputs,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None:
image_feature_size = get_clip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=image_token_id,
repeat_count=image_feature_size,
)
# NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
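For illustration only (this snippet is not part of the diff): the processor above expands the single image placeholder token in the prompt into one copy per CLIP image feature, so that the merged vision embeddings later have one slot each. A standalone toy re-statement of that idea, with made-up names and token ids:

from typing import List

def expand_image_tokens(token_ids: List[int], image_token_id: int,
                        repeat_count: int) -> List[int]:
    # Replace each placeholder token with `repeat_count` copies of itself,
    # reserving one position per visual feature produced by the vision tower.
    new_ids: List[int] = []
    for tok in token_ids:
        if tok == image_token_id:
            new_ids.extend([image_token_id] * repeat_count)
        else:
            new_ids.append(tok)
    return new_ids

# E.g. a 336px CLIP ViT-L/14 tower yields 24 x 24 = 576 patch features:
assert len(expand_image_tokens([1, 32000, 2], 32000, 576)) == 578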
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):

@@ -6,7 +6,7 @@ from transformers import CLIPVisionConfig, LlavaConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
@@ -20,8 +20,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
input_processor_for_clip)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
@@ -51,28 +53,10 @@ class LlavaMultiModalProjector(nn.Module):
return hidden_states
def merge_vision_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: torch.Tensor,
image_token_id: int) -> torch.Tensor:
"""In place merges in vision_embeddings with inputs_embeds."""
mask = (input_ids == image_token_id)
image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
if mask.sum() != image_feature_size:
raise ValueError(f"image_feature_size should be {image_feature_size}, "
f"but found: {mask.sum()}")
inputs_embeds[mask] = vision_embeddings.view(image_feature_size,
vision_embeddings.shape[-1])
return inputs_embeds
class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""
"""Shape: `(batch_size, num_channels, height, width)`"""
LlavaImageInputs = LlavaImagePixelInputs
@@ -96,8 +80,30 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int):
raise NotImplementedError(msg)
def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip(
model_config,
vision_config,
llm_inputs,
image_token_id=hf_config.image_token_index,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def __init__(self,
@@ -112,7 +118,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = CLIPVisionModel(config.vision_config)
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,

@@ -1,4 +1,4 @@
from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
import torch
import torch.nn as nn
@@ -10,7 +10,7 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
@@ -21,13 +21,14 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_patch_grid_length)
get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsVision
from .llava import LlavaMultiModalProjector, merge_vision_embeddings
from .llava import LlavaMultiModalProjector
from .utils import merge_vision_embeddings
logger = init_logger(__name__)
@@ -39,16 +40,27 @@ _KEYS_TO_MODIFY_MAPPING = {
class LlavaNextImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
data: BatchedTensors
"""Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch.
"""
image_sizes: NotRequired[torch.Tensor]
"""Shape: (batch_size, 2)"""
"""
Shape: `(batch_size, 2)`
This should be in `(height, width)` format.
"""
LlavaNextImageInputs = LlavaNextImagePixelInputs
# Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91
# NOTE: new_height and new_width are further incremented to properly invert the
# floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133
def _get_llava_next_num_unpadded_features(
height: int,
width: int,
@@ -56,7 +68,6 @@ def _get_llava_next_num_unpadded_features(
num_patch_height: int,
num_patch_width: int,
) -> Tuple[int, int]:
# Taken from: https://github.com/huggingface/text-generation-inference/blob/799a193b109662743bed1b18a09af1fdcd508c8b/server/text_generation_server/models/vlm_causal_lm.py#L111
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
@@ -64,9 +75,13 @@ def _get_llava_next_num_unpadded_features(
current_aspect_ratio: float = current_width / current_height
if aspect_ratio > current_aspect_ratio:
new_height = (height * current_width) // width
if new_height % 2 == 1:
new_height += 1
current_height = new_height
else:
new_width = (width * current_height) // height
if new_width % 2 == 1:
new_width += 1
current_width = new_width
unpadded_features = current_height * current_width
@@ -74,7 +89,8 @@ def _get_llava_next_num_unpadded_features(
return (unpadded_features, newline_features)
def _get_llava_next_image_feature_size(
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111
def get_llava_next_image_feature_size(
hf_config: LlavaNextConfig,
*,
input_height: int,
@@ -89,7 +105,9 @@ def _get_llava_next_image_feature_size(
)
base_feature_size = num_patches * num_patches
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
# Note: We follow the "wrong" width/height order
# [ref: PR huggingface/transformers#31588]
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_size=(input_height, input_width),
grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=vision_config.image_size,
@@ -110,14 +128,16 @@ def _get_llava_next_image_feature_size(
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
multimodal_config = ctx.get_multimodal_config()
hf_config = ctx.get_hf_config(LlavaNextConfig)
vision_config = hf_config.vision_config
#TODO: change the logic for dummy data to support dynamic shape
_, _, dummy_height, dummy_width = multimodal_config.image_input_shape
image_feature_size = _get_llava_next_image_feature_size(
hf_config, input_height=dummy_height, input_width=dummy_width)
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
dummy_height = dummy_width = 448
image_feature_size = get_llava_next_image_feature_size(
hf_config,
input_height=dummy_height,
input_width=dummy_width,
)
if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip(
@@ -139,27 +159,47 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
raise NotImplementedError(msg)
def _pixel_mapper(ctx: InputContext, image: object) -> Dict[str, torch.Tensor]:
if isinstance(image, Image.Image):
# Temporary patch before dynamic number of image tokens is supported
_, _, h, w = ctx.get_multimodal_config().image_input_shape
if (w, h) != (image.width, image.height):
logger.warning(
"Dynamic image shape is currently not supported. "
"Resizing input image to (%d, %d).", w, h)
image = image.resize((w, h))
return MULTIMODAL_REGISTRY._get_plugin("image") \
._default_input_mapper(ctx, image)
raise TypeError(f"Invalid type for 'image': {type(image)}")
def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaNextConfig)
vision_config = hf_config.vision_config
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
width, height = image_data.size
image_feature_size = get_llava_next_image_feature_size(
hf_config,
input_height=height,
input_width=width,
)
elif isinstance(image_data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip(
model_config,
vision_config,
llm_inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_image_input_mapper(_pixel_mapper)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
def __init__(self,
@@ -172,8 +212,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
self.config = config
self.vlm_config = vlm_config
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = CLIPVisionModel(config=config.vision_config)
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
@@ -196,24 +236,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor:
_, num_channels, _, _ = self.vlm_config.image_input_shape
# Note that this is different from that of vLLM vision_language_config
# since the image is resized by the HuggingFace preprocessor
height = width = self.config.vision_config.image_size
if list(data.shape[2:]) != [num_channels, height, width]:
raise ValueError(
f"The expected image tensor shape is batch dimension plus "
f"num_patches plus {[num_channels, height, width]}. "
f"You supplied {data.shape}. "
f"If you are using vLLM's entrypoint, make sure your "
f"supplied image input is consistent with "
f"image_input_shape in engine args.")
return data
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape[1:]) != [2]:
raise ValueError(
@@ -223,14 +245,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
if pixel_values is None or image_sizes is None:
if pixel_values is None:
return None
if not isinstance(pixel_values, torch.Tensor):
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
@@ -240,7 +262,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
return LlavaNextImagePixelInputs(
type="pixel_values",
data=self._validate_image_pixels(pixel_values),
data=pixel_values,
image_sizes=self._validate_image_sizes(image_sizes),
)
@@ -267,15 +289,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
strategy=self.config.vision_feature_select_strategy,
)
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
patch_embeddings: torch.Tensor, *,
strategy: str) -> torch.Tensor:
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
if strategy == "flat":
return patch_embeddings.flatten(0, 1)
if strategy.startswith("spatial"):
orig_width, orig_height = image_size
height = width = self.config.vision_config.image_size \
// self.config.vision_config.patch_size
@@ -289,13 +310,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
other_patch_embeds = patch_embeddings[1:]
# image_aspect_ratio == "anyres"
# Note: We follow the "wrong" width/height order
# [ref: PR huggingface/transformers#31588]
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
(orig_width, orig_height),
image_size,
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
other_patch_embeds = other_patch_embeds \
.view(num_patch_width, num_patch_height, height, width, -1)
.view(num_patch_height, num_patch_width, height, width, -1)
if "unpad" in strategy:
other_patch_embeds = other_patch_embeds \
@@ -333,44 +356,53 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
raise ValueError(f"Unexpected patch merge strategy: {strategy}")
def _process_image_pixels(
self, inputs: LlavaNextImagePixelInputs) -> torch.Tensor:
assert self.vision_tower is not None
pixel_values = inputs["data"]
b, num_patches, c, h, w = pixel_values.shape
stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values)
return stacked_image_features.view(b, num_patches,
*stacked_image_features.shape[-2:])
def _process_image_pixels(
self,
inputs: LlavaNextImagePixelInputs,
) -> BatchedTensors:
assert self.vision_tower is not None
pixel_values = inputs["data"]
if isinstance(pixel_values, torch.Tensor):
b, num_patches, c, h, w = pixel_values.shape
stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values)
stacked_patch_embeddings = self.multi_modal_projector(
stacked_image_features)
return stacked_patch_embeddings.view(
b, num_patches, *stacked_patch_embeddings.shape[1:])
num_patches_per_batch = [v.shape[0] for v in pixel_values]
stacked_pixel_values = torch.cat(pixel_values)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values)
return [
self.multi_modal_projector(image_features) for image_features in
torch.split(stacked_image_features, num_patches_per_batch)
]
def _process_image_input(
self, image_input: LlavaNextImageInputs) -> torch.Tensor:
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
patch_embeddings = self.multi_modal_projector(image_features)
image_sizes = image_input.get("image_sizes")
if image_sizes is None:
batch_size = image_input["data"].shape[0]
vision_config = self.config.vision_config
default_width = default_height = vision_config.image_size
image_sizes = torch.as_tensor([[default_width, default_height]
for _ in range(batch_size)])
merged_patch_embeddings = [
self._merge_image_patch_embeddings(image_sizes[i],
patch_features,
strategy="spatial_unpad")
for i, patch_features in enumerate(patch_embeddings)
]
return torch.stack(merged_patch_embeddings, dim=0)
def _process_image_input(
self, image_input: LlavaNextImageInputs) -> BatchedTensors:
patch_embeddings = self._process_image_pixels(image_input)
image_sizes = image_input.get("image_sizes")
if image_sizes is None:
batch_size = len(image_input["data"])
vision_config = self.config.vision_config
default_height = default_width = vision_config.image_size
image_sizes = torch.as_tensor([[default_height, default_width]
for _ in range(batch_size)])
return [
self._merge_image_patch_embeddings(image_sizes[i],
patch_features_batch,
strategy="spatial_unpad")
for i, patch_features_batch in enumerate(patch_embeddings)
]
def forward(
self,
input_ids: torch.Tensor,
@@ -404,8 +436,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each grid patch for each input image.
Expects a batch with shape `[1, num_patches, 3, 336, 336]`.
Expects a batch with shape `[1, num_patches, 3, h, w]`.
image_sizes: The original `(width, height)` for each input image.
image_sizes: The original `(height, width)` for each input image.
Expects a batch with shape `[1, 2]`.
See also:

@@ -13,7 +13,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict
import re
from functools import lru_cache
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
import numpy as np
import torch
@@ -22,8 +24,8 @@ from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
@@ -34,10 +36,12 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
from vllm.multimodal.image import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
input_processor_for_clip)
from .interfaces import SupportsVision
logger = init_logger(__name__)
@@ -251,50 +255,22 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
class Phi3VImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
data: BatchedTensors
"""Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch.
"""
image_sizes: torch.Tensor
"""Shape: (batch_size, 2)"""
"""
Shape: `(batch_size, 2)`
This should be in `(height, width)` format.
"""
def _get_phi3v_image_feature_size(
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
*,
input_height: int,
input_width: int,
) -> int:
h, w = input_height, input_width
# https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L178
return (h // 336 * w // 336 + 1) * 144 + 1 + (h // 336 + 1) * 12
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
multimodal_config = ctx.get_multimodal_config()
#TODO: change the logic for dummy data to support dynamic shape
_, _, dummy_height, dummy_width = multimodal_config.image_input_shape
image_feature_size = _get_phi3v_image_feature_size(
input_height=dummy_height,
input_width=dummy_width,
)
seq_data = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
seq_len,
image_token_id=32044,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
image_width_override=dummy_width,
image_height_override=dummy_height,
)
return seq_data, mm_data
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
target_height = int(np.ceil(height / padding_unit) * padding_unit)
top_padding = int((target_height - height) / 2)
@@ -304,7 +280,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
return padded_width, padded_height
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
transposed = False
if width < height:
@@ -329,27 +305,133 @@ def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
return padded_width, padded_height
def _image_processor(ctx: InputContext,
image: object) -> Dict[str, torch.Tensor]:
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
def get_phi3v_image_feature_size(
hf_config: PretrainedConfig,
*,
input_height: int,
input_width: int,
) -> int:
num_crops = getattr(hf_config, "num_crops", 16)
new_width, new_height = _calc_hd_transform_size(width=input_width,
height=input_height,
hd_num=num_crops)
return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
+ (new_height // 336 + 1) * 12
if isinstance(image, Image.Image):
# Temporary patch before dynamic number of image tokens is supported
_, _, h, w = ctx.get_multimodal_config().image_input_shape
if (w, h) != _calc_hd_transform_size(width=image.width,
height=image.height):
logger.warning(
"Dynamic image shape is currently not supported. "
"Resizing input image to (%d, %d).", w, h)
image = image.resize((w, h))
return MULTIMODAL_REGISTRY._get_plugin("image") \
._default_input_mapper(ctx, image)
raise TypeError(f"Invalid type for 'image': {type(image)}")
@MULTIMODAL_REGISTRY.register_image_input_mapper(_image_processor)
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
# Result in the max possible feature size (h:w = 16:1)
dummy_height, dummy_width = 8000, 50
image_feature_size = get_phi3v_image_feature_size(
ctx.get_hf_config(PretrainedConfig),
input_height=dummy_height,
input_width=dummy_width,
)
seq_data = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
seq_len,
image_token_id=32044,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
image_width_override=dummy_width,
image_height_override=dummy_height,
)
return seq_data, mm_data
# Reserve this function to also handle placeholders for additional images
# [ref: PR #5820]
@lru_cache
def _get_image_placeholder_token_ids(model_config: ModelConfig,
idx: int) -> List[int]:
assert idx > 0
tokenizer = cached_get_tokenizer(model_config.tokenizer)
# We need to get the token for "<", not "▁<"
# https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
a_token_id, = tokenizer.encode("a", add_special_tokens=False)
a_token_id_, *image_placeholder_token_ids = tokenizer.encode(
f"a<|image_{idx}|>", add_special_tokens=False)
assert a_token_id == a_token_id_
return image_placeholder_token_ids
def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
model_config = ctx.model_config
multimodal_config = ctx.get_multimodal_config()
hf_config = ctx.get_hf_config(PretrainedConfig)
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
w, h = image_data.size
w, h = _calc_hd_transform_size(width=w, height=h)
image_feature_size = get_phi3v_image_feature_size(hf_config,
input_width=w,
input_height=h)
elif isinstance(image_data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
prompt = llm_inputs.get("prompt")
if prompt is None:
new_prompt = None
else:
if prompt.count("<|image|>") > 0:
logger.warning("Please follow the prompt format that is "
"documented on HuggingFace which does not involve "
"repeating <|image|> tokens.")
elif len(re.findall(r"(<\|image_\d+\|>)+", prompt)) > 1:
logger.warning("Multiple image input is not supported yet, "
"so any extra image tokens will be treated "
"as plain text.")
new_prompt = prompt
prompt_token_ids = llm_inputs["prompt_token_ids"]
image_1_token_ids = _get_image_placeholder_token_ids(model_config, idx=1)
new_token_ids: List[int] = []
for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
new_token_ids.append(multimodal_config.image_token_id)
# No need to further scan the list since we only replace once
new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
break
else:
new_token_ids.append(prompt_token_ids[i])
# NOTE: Create a defensive copy of the original inputs
llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
return input_processor_for_clip(
model_config,
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
llm_inputs,
image_token_id=multimodal_config.image_token_id,
image_feature_size_override=image_feature_size,
)
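For illustration only (not part of the diff): the rewrite above collapses the first occurrence of the "<|image_1|>" token-id subsequence into the single image token id and copies the rest of the prompt unchanged. A standalone toy version of that scan (the function name is invented; 32044 is the image token id used elsewhere in this file):

from typing import List

def replace_first_subsequence(token_ids: List[int], placeholder: List[int],
                              image_token_id: int) -> List[int]:
    # Scan for the placeholder token-id subsequence and collapse the first
    # match into a single image token; later occurrences are left untouched.
    n = len(placeholder)
    for i in range(len(token_ids) - n + 1):
        if token_ids[i:i + n] == placeholder:
            return token_ids[:i] + [image_token_id] + token_ids[i + n:]
    return list(token_ids)

assert replace_first_subsequence([5, 7, 8, 9, 6], [7, 8, 9], 32044) == [5, 32044, 6]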
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsVision):
def __init__(self,
@@ -363,6 +445,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
self.vlm_config = vlm_config
self.model = LlamaModel(config, cache_config, quant_config)
# TODO: Optionally initializes this for supporting embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(
vlm_config, config, self.model.embed_tokens)
self.lm_head = ParallelLMHead(config.vocab_size,
@@ -376,13 +460,21 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
if pixel_values is not None and image_sizes is not None:
if pixel_values is None:
return None
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(image_sizes, torch.Tensor):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
return Phi3VImagePixelInputs(type="pixel_values", return Phi3VImagePixelInputs(type="pixel_values",
data=pixel_values, data=pixel_values,
image_sizes=image_sizes) image_sizes=image_sizes)
return None
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,

@@ -0,0 +1,41 @@
import torch
from vllm.multimodal import BatchedTensors
def merge_vision_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: BatchedTensors,
image_token_id: int) -> torch.Tensor:
"""
Merge `vision_embeddings` into `inputs_embeds` by overwriting the positions
in `inputs_embeds` corresponding to placeholder image tokens in `input_ids`.
Note:
This updates `inputs_embeds` in place.
"""
mask = (input_ids == image_token_id)
num_expected_tokens = mask.sum()
if isinstance(vision_embeddings, torch.Tensor):
batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape
total_tokens = batch_size * batch_tokens
if num_expected_tokens != total_tokens:
expr = f"{batch_size} x {batch_tokens}"
raise ValueError(
f"Attempted to assign {expr} = {total_tokens} "
f"image tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim)
else:
size_per_batch = [t.shape[0] for t in vision_embeddings]
total_tokens = sum(size_per_batch)
if num_expected_tokens != total_tokens:
expr = ' + '.join(map(str, size_per_batch))
raise ValueError(
f"Attempted to assign {expr} = {total_tokens} "
f"image tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = torch.cat(vision_embeddings)
return inputs_embeds
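For illustration only (not part of the diff), a minimal sketch of how the helper above behaves with a batched tensor input; the token id and sizes are made up:

import torch

from vllm.model_executor.models.utils import merge_vision_embeddings

IMAGE_TOKEN_ID = 32000  # made-up placeholder id
input_ids = torch.tensor([1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2])
inputs_embeds = torch.zeros(4, 8)
# (num_images, tokens_per_image, embed_dim); 1 x 2 rows match the 2 placeholders
vision_embeddings = torch.ones(1, 2, 8)

out = merge_vision_embeddings(input_ids, inputs_embeds, vision_embeddings,
                              IMAGE_TOKEN_ID)
assert out is inputs_embeds  # the merge happens in place
assert out[1:3].eq(1).all() and out[0].eq(0).all()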

@@ -1,4 +1,5 @@
from .base import MultiModalDataDict, MultiModalPlugin
from .base import (BatchedTensors, MultiModalDataDict, MultiModalInputs,
MultiModalPlugin)
from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
@@ -11,8 +12,10 @@ See also:
"""
__all__ = [
"BatchedTensors",
"MultiModalDataDict",
"MultiModalInputs",
"MultiModalPlugin",
"MULTIMODAL_REGISTRY",
"MultiModalRegistry",
"MultiModalDataDict",
]

View File

@@ -1,23 +1,90 @@
import sys
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Type,
TypedDict, TypeVar, Union)
from collections import UserDict, defaultdict
from typing import (Any, Callable, Dict, List, Optional, Type, TypedDict,
TypeVar, Union)
import torch
import torch.types
from PIL import Image
from torch import nn
from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger
if TYPE_CHECKING:
import torch
from PIL import Image
from torch import nn
logger = init_logger(__name__)
N = TypeVar("N", bound=Type["nn.Module"])
BatchedTensors = Union[torch.Tensor, List[torch.Tensor]]
"""
If each input tensor in the batch has the same size, this is a single batched
tensor; otherwise, this is a list of tensors with one element per batch.
"""
if sys.version_info < (3, 9):
# UserDict cannot be subscripted
class _MultiModalInputsBase(UserDict):
pass
else:
class _MultiModalInputsBase(UserDict[str, torch.Tensor]):
pass
class MultiModalInputs(_MultiModalInputsBase):
"""
A dictionary that represents the keyword arguments to
:meth:`~torch.nn.Module.forward`.
"""
@staticmethod
def try_concat(
tensors: List[torch.Tensor],
*,
device: torch.types.Device,
) -> BatchedTensors:
# Avoid initializing CUDA too early
import torch
unbatched_shape = tensors[0].shape[1:]
for tensor in tensors:
if tensor.shape[1:] != unbatched_shape:
return [
tensor.squeeze(0).to(device=device) for tensor in tensors
]
return torch.cat(tensors, dim=0).to(device=device)
@staticmethod
def batch(
inputs_list: List["MultiModalInputs"],
device: torch.types.Device,
) -> Dict[str, BatchedTensors]:
"""Batch multiple inputs together into a dictionary."""
if len(inputs_list) == 0:
return {}
keys = inputs_list[0].keys()
item_lists: Dict[str, List[torch.Tensor]] = defaultdict(list)
for inputs in inputs_list:
if inputs.keys() != keys:
msg = f"Inputs do not share the same keys ({keys})"
raise ValueError(msg)
for k, v in inputs.items():
item_lists[k].append(v)
return {
k: MultiModalInputs.try_concat(item_list, device=device)
for k, item_list in item_lists.items()
}
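For illustration only (not part of the diff), a rough sketch of the batching behaviour defined above, using made-up pixel-value shapes and the CPU device:

import torch

from vllm.multimodal import MultiModalInputs

a = MultiModalInputs({"pixel_values": torch.zeros(1, 3, 336, 336)})
b = MultiModalInputs({"pixel_values": torch.zeros(1, 3, 336, 336)})
c = MultiModalInputs({"pixel_values": torch.zeros(1, 5, 3, 336, 336)})

# Matching per-item shapes are concatenated along the batch dimension.
same = MultiModalInputs.batch([a, b], device="cpu")
assert same["pixel_values"].shape == (2, 3, 336, 336)

# Mismatched shapes (e.g. a different number of patches) fall back to a list
# of per-item tensors with the leading dimension squeezed away.
mixed = MultiModalInputs.batch([a, c], device="cpu")
assert isinstance(mixed["pixel_values"], list)
assert mixed["pixel_values"][1].shape == (5, 3, 336, 336)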
class MultiModalDataBuiltins(TypedDict, total=False): class MultiModalDataBuiltins(TypedDict, total=False):
image: "Image.Image" image: Image.Image
MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]] MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
@ -29,12 +96,13 @@ to the model by the corresponding mapper. By default, the mapper of
the corresponding plugin with the same modality key is applied. the corresponding plugin with the same modality key is applied.
""" """
MultiModalInputMapper = Callable[[InputContext, object], Dict[str, MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs]
"torch.Tensor"]]
"""Return a dictionary to be passed as keyword arguments to """Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers :meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers.""" and processors in HuggingFace Transformers."""
N = TypeVar("N", bound=Type[nn.Module])
class MultiModalPlugin(ABC): class MultiModalPlugin(ABC):
""" """
@ -48,8 +116,7 @@ class MultiModalPlugin(ABC):
""" """
def __init__(self) -> None: def __init__(self) -> None:
self._input_mappers: Dict[Type["nn.Module"], self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
MultiModalInputMapper] = {}
@abstractmethod @abstractmethod
def get_data_key(self) -> str: def get_data_key(self) -> str:
@ -60,7 +127,7 @@ class MultiModalPlugin(ABC):
@abstractmethod @abstractmethod
def _default_input_mapper(self, ctx: InputContext, def _default_input_mapper(self, ctx: InputContext,
data: object) -> Dict[str, "torch.Tensor"]: data: object) -> MultiModalInputs:
"""Return a dictionary to be passed as keyword arguments to """Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to :meth:`~torch.nn.Module.forward`. This is similar in concept to
tokenizers and processors in HuggingFace Transformers. tokenizers and processors in HuggingFace Transformers.
@ -80,6 +147,7 @@ class MultiModalPlugin(ABC):
See also: See also:
:ref:`input_processing_pipeline` :ref:`input_processing_pipeline`
:ref:`adding_a_new_multimodal_model`
""" """
def wrapper(model_cls: N) -> N: def wrapper(model_cls: N) -> N:
@ -97,7 +165,7 @@ class MultiModalPlugin(ABC):
return wrapper return wrapper
def map_input(self, model_config: ModelConfig, def map_input(self, model_config: ModelConfig,
data: object) -> Dict[str, "torch.Tensor"]: data: object) -> MultiModalInputs:
""" """
Apply an input mapper to a data passed Apply an input mapper to a data passed
to the model, transforming the data into a dictionary of model inputs. to the model, transforming the data into a dictionary of model inputs.
@ -106,7 +174,8 @@ class MultiModalPlugin(ABC):
The model is identified by ``model_config``. The model is identified by ``model_config``.
TODO: Add guide [ref: PR #5276] See also:
:ref:`adding_a_new_multimodal_model`
""" """
# Avoid circular import # Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture from vllm.model_executor.model_loader import get_model_architecture
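
The batching helper above is the core of dynamic image size support: per-request inputs whose item shapes agree are concatenated, while mismatched shapes fall back to a list of tensors. A minimal sketch of that behaviour (the tensor shapes and the ``device`` argument here are illustrative, not taken from the commit):

.. code-block:: python

    import torch

    from vllm.multimodal import MultiModalInputs

    # Keyword arguments produced by an input mapper, one per request.
    a = MultiModalInputs({"pixel_values": torch.zeros(1, 3, 336, 336)})
    b = MultiModalInputs({"pixel_values": torch.zeros(1, 3, 336, 336)})

    # Same per-item shape: the values are concatenated into one batched tensor.
    batched = MultiModalInputs.batch([a, b], device="cpu")
    assert batched["pixel_values"].shape == (2, 3, 336, 336)

    # Different image sizes: the result is a list with one tensor per request.
    c = MultiModalInputs({"pixel_values": torch.zeros(1, 3, 672, 336)})
    mixed = MultiModalInputs.batch([a, c], device="cpu")
    assert isinstance(mixed["pixel_values"], list)
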
View File

@@ -1,19 +1,102 @@
 from functools import lru_cache
-from typing import Dict
+from typing import List, Optional, Tuple, TypeVar
 import torch
 from PIL import Image
+from transformers import PreTrainedTokenizerBase
 from vllm.config import ModelConfig
 from vllm.inputs.registry import InputContext
 from vllm.logger import init_logger
 from vllm.transformers_utils.image_processor import get_image_processor
+from vllm.transformers_utils.tokenizer import get_tokenizer
-from .base import MultiModalPlugin
+from .base import MultiModalInputs, MultiModalPlugin
 logger = init_logger(__name__)
 cached_get_image_processor = lru_cache(get_image_processor)
+cached_get_tokenizer = lru_cache(get_tokenizer)
+# Utilities for image input processors
+_T = TypeVar("_T", str, int)
+def repeat_and_pad_token(
+    token: _T,
+    *,
+    repeat_count: int = 1,
+    pad_token_left: Optional[_T] = None,
+    pad_token_right: Optional[_T] = None,
+) -> List[_T]:
+    replacement = [token] * repeat_count
+    if pad_token_left is not None:
+        replacement = [pad_token_left] + replacement
+    if pad_token_right is not None:
+        replacement = replacement + [pad_token_right]
+    return replacement
+def repeat_and_pad_image_tokens(
+    tokenizer: PreTrainedTokenizerBase,
+    prompt: Optional[str],
+    prompt_token_ids: List[int],
+    *,
+    image_token_id: int,
+    repeat_count: int = 1,
+    pad_token_left: Optional[int] = None,
+    pad_token_right: Optional[int] = None,
+) -> Tuple[Optional[str], List[int]]:
+    if prompt is None:
+        new_prompt = None
+    else:
+        image_token_str = tokenizer.decode(image_token_id)
+        pad_token_str_left = (None if pad_token_left is None else
+                              tokenizer.decode(pad_token_left))
+        pad_token_str_right = (None if pad_token_right is None else
+                               tokenizer.decode(pad_token_right))
+        replacement_str = "".join(
+            repeat_and_pad_token(
+                image_token_str,
+                repeat_count=repeat_count,
+                pad_token_left=pad_token_str_left,
+                pad_token_right=pad_token_str_right,
+            ))
+        image_token_count = prompt.count(image_token_str)
+        # This is an arbitrary number to distinguish between the two cases
+        if image_token_count > 16:
+            logger.warning(
+                "Please follow the prompt format that is "
+                "documented on HuggingFace which does not involve "
+                "repeating %s tokens.", image_token_str)
+        elif image_token_count > 1:
+            logger.warning("Multiple image input is not supported yet, "
+                           "so any extra image tokens will be treated "
+                           "as plain text.")
+        # The image tokens are removed to be consistent with HuggingFace
+        new_prompt = prompt.replace(image_token_str, replacement_str, 1)
+    new_token_ids: List[int] = []
+    for i, token in enumerate(prompt_token_ids):
+        if token == image_token_id:
+            replacement_ids = repeat_and_pad_token(
+                image_token_id,
+                repeat_count=repeat_count,
+                pad_token_left=pad_token_left,
+                pad_token_right=pad_token_right,
+            )
+            new_token_ids.extend(replacement_ids)
+            # No need to further scan the list since we only replace once
+            new_token_ids.extend(prompt_token_ids[i + 1:])
+            break
+        else:
+            new_token_ids.append(token)
+    return new_prompt, new_token_ids
 class ImagePlugin(MultiModalPlugin):
@@ -27,7 +110,7 @@ class ImagePlugin(MultiModalPlugin):
            trust_remote_code=model_config.trust_remote_code)
     def _default_input_mapper(self, ctx: InputContext,
-                              data: object) -> Dict[str, torch.Tensor]:
+                              data: object) -> MultiModalInputs:
        model_config = ctx.model_config
        if isinstance(data, Image.Image):
            image_processor = self._get_hf_image_processor(model_config)
@@ -35,10 +118,15 @@ class ImagePlugin(MultiModalPlugin):
                raise RuntimeError("No HuggingFace processor is available"
                                   "to process the image object")
            try:
-                return image_processor.preprocess(data, return_tensors="pt") \
-                    .to(model_config.dtype).data
+                batch_data = image_processor \
+                    .preprocess(data, return_tensors="pt") \
+                    .data
            except Exception:
                logger.error("Failed to process image (%s)", data)
                raise
-        raise TypeError(f"Invalid type for 'image': {type(data)}")
+            return MultiModalInputs(batch_data)
+        elif isinstance(data, torch.Tensor):
+            raise NotImplementedError("Embeddings input is not supported yet")
+        raise TypeError(f"Invalid image type: {type(data)}")
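
A small usage sketch of the token-expansion helper defined above; the import path and the numeric token ids are assumptions made for illustration only:

.. code-block:: python

    # Assumed import path for the helper shown in this diff.
    from vllm.multimodal.image import repeat_and_pad_token

    # Expand a hypothetical image token id into three copies, no padding.
    print(repeat_and_pad_token(32000, repeat_count=3))
    # [32000, 32000, 32000]

    # The same expansion wrapped with hypothetical left/right padding tokens.
    print(repeat_and_pad_token(32000,
                               repeat_count=3,
                               pad_token_left=1,
                               pad_token_right=2))
    # [1, 32000, 32000, 32000, 2]

``repeat_and_pad_image_tokens`` applies the same expansion to both the prompt string and the token id list, so the number of placeholder tokens can vary with the image size.
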
View File

@@ -1,18 +1,17 @@
 import functools
-from typing import Optional, Sequence, Type, TypeVar
+from typing import Dict, Optional, Sequence
-from torch import nn
+import torch
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
-from .base import MultiModalDataDict, MultiModalInputMapper, MultiModalPlugin
+from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
+                   MultiModalPlugin)
 from .image import ImagePlugin
 logger = init_logger(__name__)
-N = TypeVar("N", bound=Type[nn.Module])
 class MultiModalRegistry:
     """
@@ -61,7 +60,7 @@ class MultiModalRegistry:
        return self.register_input_mapper("image", mapper)
     def _process_input(self, key: str, value: object,
-                       model_config: ModelConfig):
+                       model_config: ModelConfig) -> MultiModalInputs:
        plugin = self._plugins.get(key)
        if plugin:
            return plugin.map_input(model_config, value)
@@ -93,16 +92,28 @@ class MultiModalRegistry:
        """
        return self.register_input_mapper("image", mapper)
-    def map_input(self, model_config: ModelConfig, data: MultiModalDataDict):
+    def map_input(self, model_config: ModelConfig,
+                  data: MultiModalDataDict) -> MultiModalInputs:
        """
        Apply an input mapper to the data passed to the model.
        See :meth:`MultiModalPlugin.map_input` for more details.
        """
-        result_list = [
-            self._process_input(k, v, model_config) for k, v in data.items()
-        ]
-        return {k: v for d in result_list for k, v in d.items()}
+        merged_dict: Dict[str, torch.Tensor] = {}
+        for data_key, data_value in data.items():
+            input_dict = self._process_input(data_key, data_value,
+                                             model_config)
+            for input_key, input_tensor in input_dict.items():
+                if input_key in merged_dict:
+                    raise ValueError(f"The input mappers (keys={set(data)}) "
+                                     f"resulted in a conflicting keyword "
+                                     f"argument to `forward()`: {input_key}")
+                merged_dict[input_key] = input_tensor
+        return MultiModalInputs(merged_dict)
     def create_input_mapper(self, model_config: ModelConfig):
        """
View File

@@ -4,11 +4,56 @@ from typing import Optional, Union
 from urllib.parse import urlparse
 import aiohttp
+import requests
 from PIL import Image
-from vllm.config import ModelConfig
 from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
 from vllm.multimodal.base import MultiModalDataDict
+from vllm.version import __version__ as VLLM_VERSION
+def _validate_remote_url(url: str, *, name: str):
+    parsed_url = urlparse(url)
+    if parsed_url.scheme not in ["http", "https"]:
+        raise ValueError(f"Invalid '{name}': A valid '{name}' "
+                         "must have scheme 'http' or 'https'.")
+def _get_request_headers():
+    return {"User-Agent": f"vLLM/{VLLM_VERSION}"}
+def _load_image_from_bytes(b: bytes):
+    image = Image.open(BytesIO(b))
+    image.load()
+    return image
+def _load_image_from_data_url(image_url: str):
+    # Only split once and assume the second part is the base64 encoded image
+    _, image_base64 = image_url.split(",", 1)
+    return load_image_from_base64(image_base64)
+def fetch_image(image_url: str) -> Image.Image:
+    """Load PIL image from a url or base64 encoded openai GPT4V format"""
+    if image_url.startswith('http'):
+        _validate_remote_url(image_url, name="image_url")
+        headers = _get_request_headers()
+        with requests.get(url=image_url, headers=headers) as response:
+            response.raise_for_status()
+            image_raw = response.content
+        image = _load_image_from_bytes(image_raw)
+    elif image_url.startswith('data:image'):
+        image = _load_image_from_data_url(image_url)
+    else:
+        raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
+                         "with either 'data:image' or 'http'.")
+    return image
 class ImageFetchAiohttp:
@@ -29,34 +74,31 @@ class ImageFetchAiohttp:
        """Load PIL image from a url or base64 encoded openai GPT4V format"""
        if image_url.startswith('http'):
-            parsed_url = urlparse(image_url)
-            if parsed_url.scheme not in ["http", "https"]:
-                raise ValueError("Invalid 'image_url': A valid 'image_url' "
-                                 "must have scheme 'http' or 'https'.")
-            # Avoid circular import
-            from vllm import __version__ as VLLM_VERSION
+            _validate_remote_url(image_url, name="image_url")
            client = cls.get_aiohttp_client()
-            headers = {"User-Agent": f"vLLM/{VLLM_VERSION}"}
+            headers = _get_request_headers()
            async with client.get(url=image_url, headers=headers) as response:
                response.raise_for_status()
                image_raw = await response.read()
-            image = Image.open(BytesIO(image_raw))
+            image = _load_image_from_bytes(image_raw)
-        # Only split once and assume the second part is the base64 encoded image
        elif image_url.startswith('data:image'):
-            image = load_image_from_base64(image_url.split(',', 1)[1])
+            image = _load_image_from_data_url(image_url)
        else:
            raise ValueError(
                "Invalid 'image_url': A valid 'image_url' must start "
                "with either 'data:image' or 'http'.")
-        image.load()
        return image
+async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
+    image = await ImageFetchAiohttp.fetch_image(image_url)
+    return {"image": image}
 def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str:
     """Encode a pillow image to base64 format."""
@@ -69,26 +111,11 @@ def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str:
 def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
     """Load image from base64 format."""
-    return Image.open(BytesIO(base64.b64decode(image)))
+    return _load_image_from_bytes(base64.b64decode(image))
-# TODO(ywang96): move this to a model registry for preprocessing vision
-# language prompts based on the model type.
-def get_full_image_text_prompt(image_prompt: str, text_prompt: str,
-                               config: ModelConfig) -> str:
-    """Combine image and text prompts for vision language model depending on
-    the model architecture."""
-    if config.hf_config.model_type in ("llava", "llava_next"):
-        full_prompt = f"{image_prompt}\n{text_prompt}"
-    elif config.hf_config.model_type == 'phi3_v':
-        full_prompt = f"{image_prompt}<s>\n{text_prompt}"
-    else:
-        raise ValueError(
-            f"Unsupported model type: {config.hf_config.model_type}")
-    return full_prompt
-async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
-    image = await ImageFetchAiohttp.fetch_image(image_url)
-    return {"image": image}
+def rescale_image_size(image: Image.Image, size_factor: float) -> Image.Image:
+    """Rescale the dimensions of an image by a constant factor."""
+    new_width = int(image.width * size_factor)
+    new_height = int(image.height * size_factor)
+    return image.resize((new_width, new_height))
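
A hedged sketch of the synchronous helpers added above; the import path is assumed, and the JPEG round-trip is lossy, which is fine for illustration:

.. code-block:: python

    from PIL import Image

    # Assumed import path for the helpers shown in this diff.
    from vllm.multimodal.utils import (encode_image_base64,
                                       load_image_from_base64,
                                       rescale_image_size)

    image = Image.new("RGB", (640, 480))

    # Round-trip through base64, the representation used for data: image URLs.
    restored = load_image_from_base64(encode_image_base64(image))

    # Scale both dimensions by a constant factor, e.g. to test dynamic sizes.
    half = rescale_image_size(restored, 0.5)
    print(half.size)  # (320, 240)
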
View File

@@ -457,7 +457,7 @@ class SequenceGroup:
        return next(iter(self.seqs_dict.values())).prompt_token_ids
     @property
-    def multi_modal_data(self) -> Optional["MultiModalDataDict"]:
+    def multi_modal_data(self) -> "MultiModalDataDict":
        # All sequences in the group should have the same multi-modal data.
        # We use the multi-modal data of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).multi_modal_data
View File

@@ -1,9 +1,4 @@
-from transformers import AutoImageProcessor
-from transformers.image_processing_utils import BaseImageProcessor
-from vllm.logger import init_logger
-logger = init_logger(__name__)
+from typing import cast
 def get_image_processor(
@@ -11,10 +6,15 @@ def get_image_processor(
    *args,
    trust_remote_code: bool = False,
    **kwargs,
-) -> BaseImageProcessor:
+):
    """Gets an image processor for the given model name via HuggingFace."""
+    # don't put this import at the top level
+    # it will call torch.cuda.device_count()
+    from transformers import AutoImageProcessor
+    from transformers.image_processing_utils import BaseImageProcessor
    try:
-        processor: BaseImageProcessor = AutoImageProcessor.from_pretrained(
+        processor = AutoImageProcessor.from_pretrained(
            processor_name,
            *args,
            trust_remote_code=trust_remote_code,
@@ -34,4 +34,4 @@ def get_image_processor(
        else:
            raise e
-    return processor
+    return cast(BaseImageProcessor, processor)
View File

@@ -1,6 +1,6 @@
-from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
+                    Type, Union)
 import torch
 from torch import nn
@@ -12,7 +12,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.model_loader import get_model
-from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                             MultiModalInputs)
 from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.utils import make_tensor_with_pad
@@ -40,7 +41,7 @@ class CPUModelInput(ModelRunnerInputBase):
    input_positions: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    sampling_metadata: Optional["SamplingMetadata"] = None
-    multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
+    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
@@ -132,15 +133,14 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], Dict[
-            str, torch.Tensor]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
+               Mapping[str, BatchedTensors]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
-        multi_modal_kwargs_list: Dict[str,
-                                      List[torch.Tensor]] = defaultdict(list)
+        multi_modal_inputs_list: List[MultiModalInputs] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
@@ -162,10 +162,9 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
            input_positions.extend(list(range(computed_len, seq_len)))
            mm_data = seq_group_metadata.multi_modal_data
-            if mm_data is not None:
+            if mm_data:
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                for k, v in mm_kwargs.items():
-                    multi_modal_kwargs_list[k].append(v)
+                multi_modal_inputs_list.append(mm_kwargs)
            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
@@ -189,11 +188,6 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)
-        multi_modal_kwargs = {
-            k: torch.cat(v, dim=0).to(self.device)
-            for k, v in multi_modal_kwargs_list.items()
-        }
        num_prompt_tokens = len(input_tokens)
        input_tokens = torch.tensor(input_tokens,
@@ -217,6 +211,10 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
            block_tables=torch.tensor([]),
            slot_mapping=slot_mapping,
        )
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_kwargs)
@@ -367,10 +365,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
            "positions": model_input.input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": model_input.attn_metadata,
+            **(model_input.multi_modal_kwargs or {}),
        }
-        if (self.vision_language_config
-                and model_input.multi_modal_kwargs is not None):
-            execute_model_kwargs.update(model_input.multi_modal_kwargs)
        hidden_states = model_executable(**execute_model_kwargs)
View File

@@ -92,10 +92,9 @@ class EmbeddingModelRunner(
            "positions": model_input.input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": model_input.attn_metadata,
+            **(model_input.multi_modal_kwargs or {}),
        }
-        if self.vision_language_config:
-            multi_modal_kwargs = model_input.multi_modal_kwargs or {}
-            execute_model_kwargs.update({"image_input": multi_modal_kwargs})
        hidden_states = model_executable(**execute_model_kwargs)
        # Only perform pooling in the driver worker.
View File

@@ -3,8 +3,8 @@ import gc
 import time
 import warnings
 from collections import defaultdict
-from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
-                    TypeVar, Union)
+from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
+                    Tuple, Type, TypeVar, Union)
 import numpy as np
 import torch
@@ -37,7 +37,8 @@ from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 from vllm.model_executor.models.interfaces import supports_lora
-from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                             MultiModalInputs)
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
@@ -83,7 +84,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
    lora_mapping: Optional["LoRAMapping"] = None
    lora_requests: Optional[Set[LoRARequest]] = None
    attn_metadata: Optional["AttentionMetadata"] = None
-    multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
+    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
    request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
    finished_requests_ids: Optional[List[str]] = None
    virtual_engine: int = 0
@@ -356,8 +357,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        context_lens: List[int] = []
        query_lens: List[int] = []
        block_tables: List[List[int]] = []
-        multi_modal_kwargs_list: Dict[str,
-                                      List[torch.Tensor]] = defaultdict(list)
+        multi_modal_inputs_list: List[MultiModalInputs] = []
        request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list)
        decode_only = True
        num_prefills = 0
@@ -528,8 +528,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
            if mm_data:
                # Process multi-modal data
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                for k, v in mm_kwargs.items():
-                    multi_modal_kwargs_list[k].append(v)
+                multi_modal_inputs_list.append(mm_kwargs)
            is_profile_run = _is_block_tables_empty(
                seq_group_metadata.block_tables)
@@ -746,10 +745,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        else:
            lora_mapping = None
-        multi_modal_kwargs = {
-            k: torch.cat(v, dim=0).to(self.device)
-            for k, v in multi_modal_kwargs_list.items()
-        }
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
        request_ids_to_seq_ids = {
            seq_group_metadata.request_id:
            list(seq_group_metadata.seq_data.keys())
@@ -821,7 +818,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
            seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
                .dummy_data_for_profiling(model_config, seq_len)
-            assert len(seq_data.prompt_token_ids) == seq_len
+            # Having more tokens is over-conservative but otherwise fine
+            assert len(seq_data.prompt_token_ids) >= seq_len, (
+                f"Expected at least {seq_len} dummy tokens for profiling, "
+                f"but got: {len(seq_data.prompt_token_ids)}")
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
View File

@@ -1,5 +1,6 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
+                    Union)
 import torch
 from torch import nn
@@ -9,6 +10,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.model_loader.neuron import get_neuron_model
+from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                             MultiModalInputs)
 from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
@@ -29,6 +32,7 @@ class ModelInputForNeuron(ModelRunnerInputBase):
    input_positions: Optional[torch.Tensor] = None
    input_block_ids: Optional[torch.Tensor] = None
    sampling_metadata: Optional["SamplingMetadata"] = None
+    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
@@ -65,6 +69,10 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()
+        # Multi-modal data support
+        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
+            .create_input_mapper(self.model_config)
        # Lazy initialization.
        self.model: nn.Module  # initialize after load_model.
@@ -76,13 +84,15 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], Mapping[
+            str, BatchedTensors]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []
        seq_lens: List[int] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
@@ -102,6 +112,12 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
            assert len(block_table) == 1
            input_block_ids.append(block_table[0])
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                # Process multi-modal data
+                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                multi_modal_inputs_list.append(mm_kwargs)
        max_seq_len = max(seq_lens)
        assert max_seq_len > 0
        input_tokens = make_tensor_with_pad(input_tokens,
@@ -118,7 +134,11 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
                                               dtype=torch.long,
                                               device=self.device)
-        return input_tokens, input_positions, input_block_ids, seq_lens
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
+        return (input_tokens, input_positions, input_block_ids, seq_lens,
+                multi_modal_kwargs)
    def _prepare_decode(
        self,
@@ -184,8 +204,9 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
        is_prompt = seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
-            (input_tokens, input_positions, input_block_ids,
-             seq_lens) = self._prepare_prompt(seq_group_metadata_list)
+            (input_tokens, input_positions, input_block_ids, seq_lens,
+             multi_modal_kwargs
+             ) = self._prepare_prompt(seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
@@ -203,7 +224,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
        return ModelInputForNeuron(input_tokens=input_tokens,
                                   input_positions=input_positions,
                                   input_block_ids=input_block_ids,
-                                   sampling_metadata=sampling_metadata)
+                                   sampling_metadata=sampling_metadata,
+                                   multi_modal_kwargs=multi_modal_kwargs)
    @torch.inference_mode()
    def execute_model(
@@ -221,6 +243,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            input_block_ids=model_input.input_block_ids,
+            **(model_input.multi_modal_kwargs or {}),
        )
        # Compute the logits.
View File

@@ -1,4 +1,4 @@
-from typing import List, NamedTuple, Optional, Tuple
+from typing import List, Mapping, NamedTuple, Optional, Tuple
 import openvino as ov
 import torch
@@ -12,6 +12,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.model_loader.openvino import get_model
+from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                             MultiModalInputs)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 logger = init_logger(__name__)
@@ -23,7 +25,7 @@ class ModelInput(NamedTuple):
    attn_metadata: Optional[OpenVINOAttentionMetadata]
    seq_lens: List[int]
    query_lens: List[int]
-    multi_modal_input: Optional[torch.Tensor]
+    multi_modal_kwargs: Mapping[str, BatchedTensors]
    @classmethod
    def empty(cls, device):
@@ -32,7 +34,7 @@ class ModelInput(NamedTuple):
                          attn_metadata=None,
                          seq_lens=[],
                          query_lens=[],
-                          multi_modal_input=None)
+                          multi_modal_kwargs={})
 class OpenVINOModelRunner:
@@ -78,6 +80,10 @@ class OpenVINOModelRunner:
            self.block_size,
        )
+        # Multi-modal data support
+        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
+            .create_input_mapper(self.model_config)
        # Lazy initialization.
        self.model: nn.Module  # Set after init_Model
@@ -108,6 +114,8 @@ class OpenVINOModelRunner:
        seq_lens: List[int] = []
        past_lens: List[int] = []
        query_lens: List[int] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []
        subsequence_begins: List[int] = []
        block_indices: List[int] = []
        block_indices_begins: List[int] = []
@@ -160,6 +168,11 @@ class OpenVINOModelRunner:
                                and self.sliding_window is None
                                and is_prompt)
+                mm_data = seq_group_metadata.multi_modal_data
+                if mm_data:
+                    mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                    multi_modal_inputs_list.append(mm_kwargs)
                block_table = seq_group_metadata.block_tables[seq_id]
                # TODO(sang): Combine chunked prefill and prefix caching by
                # only allowing multiple of block_size chunk size.
@@ -251,22 +264,24 @@ class OpenVINOModelRunner:
            block_indices_begins=block_indices_begins_tensor,
            max_context_len=max_context_len_tensor,
        )
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
        return ModelInput(
            input_tokens,
            input_positions,
            attn_metadata,
            seq_lens,
            query_lens,
-            None,
+            multi_modal_kwargs=multi_modal_kwargs,
        )
    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
-               SamplingMetadata, Optional[torch.Tensor], ]:
-        multi_modal_input = None
+               SamplingMetadata, Mapping[str, BatchedTensors]]:
        # Prepare input tensors.
        (
            input_tokens,
@@ -274,7 +289,7 @@ class OpenVINOModelRunner:
            attn_metadata,
            seq_lens,
            query_lens,
-            multi_modal_input,
+            multi_modal_kwargs,
        ) = self._prepare_model_input(seq_group_metadata_list)
        sampling_metadata = SamplingMetadata.prepare(
@@ -290,7 +305,7 @@ class OpenVINOModelRunner:
            input_positions,
            attn_metadata,
            sampling_metadata,
-            multi_modal_input,
+            multi_modal_kwargs,
        )
    @torch.inference_mode()
@@ -304,7 +319,7 @@ class OpenVINOModelRunner:
            input_positions,
            attn_metadata,
            sampling_metadata,
-            multi_modal_input,
+            multi_modal_kwargs,
        ) = self.prepare_input_tensors(seq_group_metadata_list)
        model_executable = self.model
@@ -313,9 +328,8 @@ class OpenVINOModelRunner:
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
+            **(multi_modal_kwargs or {}),
        }
-        if self.vision_language_config:
-            execute_model_kwargs.update({"image_input": multi_modal_input})
        hidden_states = model_executable(**execute_model_kwargs)
View File

@@ -1,5 +1,5 @@
 import time
-from typing import List, Optional, Tuple
+from typing import List, Mapping, Optional, Tuple
 import numpy as np
 import torch
@@ -12,6 +12,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                             MultiModalInputs)
 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SamplerOutput, SequenceGroupMetadata,
                            SequenceOutput)
@@ -66,6 +68,10 @@ class TPUModelRunner:
            False,
        )
+        # Multi-modal data support
+        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
+            .create_input_mapper(self.model_config)
    def load_model(self) -> None:
        self.device = self.device_config.device
@@ -193,12 +199,14 @@ class TPUModelRunner:
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
+               Mapping[str, BatchedTensors]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        prompt_lens: List[int] = []
        slot_mapping: List[List[int]] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
@@ -224,6 +232,11 @@ class TPUModelRunner:
                slot = block_number * self.block_size + block_offset
                slot_mapping[-1].append(slot)
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                multi_modal_inputs_list.append(mm_kwargs)
        assert len(prompt_lens) > 0
        num_prefills = len(prompt_lens)
        num_prefill_tokens = sum(prompt_lens)
@@ -261,17 +274,24 @@ class TPUModelRunner:
            block_tables=None,
            context_lens=None,
        )
-        return input_tokens, input_positions, attn_metadata, prompt_lens
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
+        return (input_tokens, input_positions, attn_metadata, prompt_lens,
+                multi_modal_kwargs)
    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
+               Mapping[str, BatchedTensors]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []
        batch_idx = 0
        for seq_group_metadata in seq_group_metadata_list:
@@ -297,6 +317,11 @@ class TPUModelRunner:
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                multi_modal_inputs_list.append(mm_kwargs)
        batch_size = _get_padded_batch_size(batch_idx)
        num_paddings = batch_size - batch_idx
        input_tokens = input_tokens + [[0]] * num_paddings
@@ -330,7 +355,12 @@ class TPUModelRunner:
            block_tables=block_tables,
            context_lens=context_lens,
        )
-        return input_tokens, input_positions, attn_metadata, input_lens
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
+        return (input_tokens, input_positions, attn_metadata, input_lens,
+                multi_modal_kwargs)
    def _prepare_sample(
        self,
@@ -483,6 +513,7 @@ class ModelWrapper(nn.Module):
        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
        attn_metadata: AttentionMetadata,
        input_lens: torch.Tensor,
+        multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]],
        t: torch.Tensor,
        p: torch.Tensor,
        num_samples: int,
@@ -496,6 +527,8 @@ class ModelWrapper(nn.Module):
                memory profiling at initialization.
            attn_metadata: The Pallas attention metadata.
            input_lens: The actual input lengths of shape [batch_size].
+            multi_modal_kwargs: Keyword arguments from multi-modal data to
+                pass to the model.
            t: The sampling temperature of shape [batch_size].
            p: The top-p probability of shape [batch_size].
        """
@@ -540,6 +573,7 @@ class ModelWrapper(nn.Module):
            position_ids,
            kv_caches,
            attn_metadata,
+            **(multi_modal_kwargs or {}),
        )
        hidden_states = hidden_states.flatten(0, 1)
        logits = self.model.compute_logits(hidden_states, sampling_metadata)
View File

@@ -1,5 +1,6 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
+                    Type, Union)
 import torch
 import torch.nn as nn
@@ -9,10 +10,13 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict
+from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
+from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                             MultiModalInputs)
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceData,
+from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad
 from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
@@ -44,7 +48,7 @@ class ModelInputForXPU(ModelRunnerInputBase):
    input_positions: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    sampling_metadata: Optional["SamplingMetadata"] = None
-    multi_modal_input: Optional[Dict[str, torch.Tensor]] = None
+    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
@@ -116,6 +120,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
            self.block_size,
        )
+        # Multi-modal data support
+        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
+            .create_input_mapper(self.model_config)
        # Lazy initialization.
        self.model: nn.Module  # Set after init_Model
@@ -156,12 +164,26 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
+        model_config = self.model_config
+        vlm_config = self.vision_language_config
+        if vlm_config:
+            max_num_seqs = min(
+                max_num_seqs,
+                int(max_num_batched_tokens / vlm_config.image_feature_size))
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
-            seq_data = SequenceData([0] * seq_len)
-            dummy_multi_modal_data = None
+            seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
+                .dummy_data_for_profiling(model_config, seq_len)
+            # Having more tokens is over-conservative but otherwise fine
+            assert len(seq_data.prompt_token_ids) >= seq_len, (
+                f"Expected at least {seq_len} dummy tokens for profiling, "
+                f"but got: {len(seq_data.prompt_token_ids)}")
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
@@ -194,7 +216,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForXPU:
-        multi_modal_input = None
+        multi_modal_kwargs = None
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
            # all decodes.
@@ -202,7 +224,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, attn_metadata, seq_lens,
-                 multi_modal_input
+                 multi_modal_kwargs
                 ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions,
@@ -223,6 +245,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
+                "multi_modal_kwargs": multi_modal_kwargs,
            }
            metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
@@ -232,6 +255,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
+            multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
@@ -244,7 +268,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
                                input_positions=input_positions,
                                attn_metadata=attn_metadata,
                                sampling_metadata=sampling_metadata,
-                                multi_modal_input=multi_modal_input)
+                                multi_modal_kwargs=multi_modal_kwargs)
    def _prepare_decode(
        self,
@@ -350,10 +374,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
            "positions": model_input.input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": model_input.attn_metadata,
+            **(model_input.multi_modal_kwargs or {}),
        }
-        if self.vision_language_config:
-            execute_model_kwargs.update(
-                {"image_input": model_input.multi_modal_input})
        hidden_states = model_executable(**execute_model_kwargs)
@@ -376,13 +398,13 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
-               Optional[torch.Tensor]]:
+               Mapping[str, BatchedTensors]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
-        multi_modal_input_list: List[torch.Tensor] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
@@ -403,9 +425,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
            # is always the first token in the sequence.
            input_positions.extend(list(range(computed_len, seq_len)))
-            if seq_group_metadata.multi_modal_data:
-                multi_modal_input_list.append(
-                    seq_group_metadata.multi_modal_data.data)
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                multi_modal_inputs_list.append(mm_kwargs)
            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
@@ -435,15 +458,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)
-        if multi_modal_input_list:
-            assert self.vision_language_config, (
-                "Multi-modal inputs are only supported by "
-                "vision language models.")
-            multi_modal_input = torch.cat(multi_modal_input_list,
-                                          dim=0).to(self.device)
-        else:
-            multi_modal_input = None
        num_prompt_tokens = len(input_tokens)
        input_tokens = torch.tensor(input_tokens,
@@ -475,5 +489,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
            num_decode_tokens=0,
            block_tables=torch.tensor([], device=self.device, dtype=torch.int),
        )
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
        return (input_tokens, input_positions, attn_metadata, seq_lens,
-                multi_modal_input)
+                multi_modal_kwargs)