[Core] Dynamic image size support for VLMs (#5276)
Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: ywang96 <ywang@roblox.com>
Co-authored-by: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
parent 482045ee77
commit 9831aec49f
@@ -8,7 +8,7 @@ Input Processing
 vLLM provides a mechanism for defining input processors for each model so that the inputs are processed
 in :class:`~vllm.LLMEngine` before they are passed to model executors.

-Currently, this mechanism is only utilized in **multi-modal models** for preprocessing multi-modal input
+Currently, this mechanism is only utilized in :ref:`multi-modal models <multi_modality>` for preprocessing multi-modal input
 data in addition to input prompt, but it can be extended to text-only language models when needed.

 Guides
docs/source/dev/multimodal/adding_multimodal_model.rst (new file, 124 lines)
@@ -0,0 +1,124 @@
.. _adding_a_new_multimodal_model:

Adding a New Multimodal Model
=============================

This document provides a high-level guide on integrating a :ref:`multi-modal model <multi_modality>` into vLLM.

.. note::
    The complexity of adding a new model depends heavily on the model's architecture.
    The process is considerably more straightforward if the model shares a similar architecture with an existing model in vLLM.
    However, for models that include new operators (e.g., a new attention mechanism), the process can be a bit more complex.

.. tip::
    If you are encountering issues while integrating your model into vLLM, feel free to open an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ repository.
    We will be happy to help you out!


1. Set up the base vLLM model
-----------------------------

As usual, follow :ref:`these steps <adding_a_new_model>` to implement the model in vLLM, but note the following:

- You should additionally implement the :class:`~vllm.model_executor.models.interfaces.SupportsVision` interface.

  .. code-block:: diff

      + from vllm.model_executor.models.interfaces import SupportsVision

      - class YourModelForImage2Seq(nn.Module):
      + class YourModelForImage2Seq(nn.Module, SupportsVision):

  .. note::
      The model class does not have to be named :code:`*ForCausalLM`.
      Check out `the HuggingFace Transformers documentation <https://huggingface.co/docs/transformers/model_doc/auto#multimodal>`__ for some examples.

- While implementing the :meth:`~torch.nn.Module.forward` method, reserve a keyword parameter
  for each input tensor that corresponds to a multi-modal input, as shown in the following example:

  .. code-block:: diff

        def forward(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            kv_caches: List[torch.Tensor],
            attn_metadata: AttentionMetadata,
      +     pixel_values: torch.Tensor,
        ) -> SamplerOutput:
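
To make the role of that keyword parameter concrete, the sketch below shows one way a model
might consume ``pixel_values`` inside :meth:`~torch.nn.Module.forward`: encode the image,
project it into the language model's embedding space, and splice it into the text embeddings.
This is a minimal illustration only; ``vision_tower``, ``multi_modal_projector``,
``merge_vision_embeddings`` and ``image_token_id`` are hypothetical placeholders rather than
vLLM APIs, so refer to the LLaVA implementations linked later in this guide for a working version.

.. code-block:: python

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        pixel_values: Optional[torch.Tensor] = None,
    ) -> SamplerOutput:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)

        if pixel_values is not None:
            # Hypothetical submodules standing in for the model's own vision
            # encoder and projection layer.
            image_embeds = self.multi_modal_projector(
                self.vision_tower(pixel_values))
            # Place the image embeddings at the positions occupied by the
            # placeholder image tokens.
            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, image_embeds, self.image_token_id)

        hidden_states = self.language_model(
            input_ids=None,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states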

2. Register input mappers
-------------------------

For each modality type to support, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in :meth:`~torch.nn.Module.forward`.

.. code-block:: diff

      from vllm.model_executor.models.interfaces import SupportsVision
    + from vllm.multimodal import MULTIMODAL_REGISTRY

    + @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
    + @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
      class YourModelForImage2Seq(nn.Module, SupportsVision):

A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.

.. seealso::
    :ref:`input_processing_pipeline`
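
If the default mapper does not fit your model, you can pass your own function to the decorator.
The sketch below is only a rough illustration: the exact mapper signature should be taken from
:meth:`~vllm.multimodal.MultiModalRegistry.register_input_mapper` itself, and ``YourModelProcessor``
is a hypothetical HuggingFace processor standing in for whatever preprocessing your model needs.

.. code-block:: python

    import torch

    from vllm.inputs.registry import InputContext
    from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs


    def map_image_pixels(ctx: InputContext, data: object) -> MultiModalInputs:
        """Turn a raw PIL image into the ``pixel_values`` keyword argument."""
        # Hypothetical HF processor; substitute your model's preprocessing.
        processor = YourModelProcessor.from_pretrained(ctx.model_config.model)
        pixel_values: torch.Tensor = processor(
            images=data, return_tensors="pt")["pixel_values"]
        return MultiModalInputs({"pixel_values": pixel_values})


    # Imports of nn and SupportsVision follow the earlier snippets.
    @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(map_image_pixels)
    class YourModelForImage2Seq(nn.Module, SupportsVision):
        ...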

3. (Optional) Register dummy data
---------------------------------

During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models.
In such cases, you can define your own dummy data by registering a factory method via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`.

.. code-block:: diff

      from vllm.inputs import INPUT_REGISTRY
      from vllm.model_executor.models.interfaces import SupportsVision
      from vllm.multimodal import MULTIMODAL_REGISTRY

      @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
      @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
    + @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
      class YourModelForImage2Seq(nn.Module, SupportsVision):

Here are some examples:

- Image inputs (static feature size): `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
- Image inputs (dynamic feature size): `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__

.. seealso::
    :ref:`input_processing_pipeline`
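
The shape of ``<your_dummy_data_factory>`` is easiest to copy from the examples above. For
orientation only, here is a heavily simplified sketch; the factory signature and return value
are assumptions modeled on the LLaVA implementations (a token sequence padded to ``seq_len``
plus a dummy image of the largest supported size), and the token id and feature size are
hypothetical values.

.. code-block:: python

    from PIL import Image

    from vllm.inputs.registry import InputContext
    from vllm.sequence import SequenceData


    def dummy_data_for_your_model(ctx: InputContext, seq_len: int):
        image_token_id = 32000   # hypothetical placeholder token id
        num_image_tokens = 576   # maximum feature size your model can produce

        # Text side: placeholder image tokens, padded up to the profiled length.
        token_ids = [image_token_id] * num_image_tokens
        token_ids += [0] * (seq_len - num_image_tokens)
        seq_data = SequenceData(token_ids)

        # Multi-modal side: a dummy image at the largest supported resolution.
        mm_data = {"image": Image.new("RGB", (336, 336), color=0)}
        return seq_data, mm_data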

4. (Optional) Register input processor
--------------------------------------

Sometimes, there is a need to process inputs at the :class:`~vllm.LLMEngine` level before they are passed to the model executor.
This is often because, unlike implementations in HuggingFace Transformers, the reshaping and/or expansion of multi-modal embeddings needs to take place outside the model's :meth:`~torch.nn.Module.forward` call.
You can register input processors via :meth:`INPUT_REGISTRY.register_input_processor <vllm.inputs.registry.InputRegistry.register_input_processor>`.

.. code-block:: diff

      from vllm.inputs import INPUT_REGISTRY
      from vllm.model_executor.models.interfaces import SupportsVision
      from vllm.multimodal import MULTIMODAL_REGISTRY

      @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
      @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
      @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
    + @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
      class YourModelForImage2Seq(nn.Module, SupportsVision):

A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
Here are some examples:

- Insert a static number of image tokens: `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
- Insert a dynamic number of image tokens: `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__

.. seealso::
    :ref:`input_processing_pipeline`
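
To make the placeholder-token use case concrete, here is a rough sketch of what
``<your_input_processor>`` might look like: it expands each ``<image>`` placeholder into one
token per image embedding. The ``LLMInputs``-in/``LLMInputs``-out signature and the field names
are assumptions modeled on the LLaVA input processors referenced above, and the token id and
feature size are hypothetical values.

.. code-block:: python

    from vllm.inputs import LLMInputs
    from vllm.inputs.registry import InputContext


    def input_processor_for_your_model(ctx: InputContext,
                                       llm_inputs: LLMInputs) -> LLMInputs:
        if llm_inputs.get("multi_modal_data") is None:
            return llm_inputs  # text-only prompt, nothing to do

        image_token_id = 32000    # hypothetical placeholder token id
        image_feature_size = 576  # number of embeddings produced per image

        # Expand the single placeholder token so that vLLM reserves the right
        # amount of KV cache and builds the correct attention mask.
        new_token_ids = []
        for token_id in llm_inputs["prompt_token_ids"]:
            if token_id == image_token_id:
                new_token_ids.extend([image_token_id] * image_feature_size)
            else:
                new_token_ids.append(token_id)

        return LLMInputs(prompt_token_ids=new_token_ids,
                         prompt=llm_inputs.get("prompt"),
                         multi_modal_data=llm_inputs["multi_modal_data"])
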
@@ -1,3 +1,5 @@
+.. _multi_modality:
+
 Multi-Modality
 ==============

@@ -8,12 +10,18 @@ vLLM provides experimental support for multi-modal models through the :mod:`vllm
 :class:`vllm.inputs.PromptStrictInputs` accepts an additional attribute ``multi_modal_data``
 which allows you to pass in multi-modal input alongside text and token prompts.

-By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model,
-you must decorate the model class with :meth:`InputRegistry.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`,
-as well as :meth:`MULTIMODAL_REGISTRY.register_input_mapper <MultiModalRegistry.register_input_mapper>` for each modality type to support.
+By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model, please follow :ref:`the guide for adding a new multimodal model <adding_a_new_multimodal_model>`.

-# TODO: Add more instructions on how to do that once embeddings is in.
+Guides
+++++++
+
+.. toctree::
+   :maxdepth: 1
+
+   adding_multimodal_model

 Module Contents
 +++++++++++++++

@@ -35,6 +43,10 @@ Base Classes
     :members:
     :show-inheritance:

+.. autoclass:: vllm.multimodal.MultiModalInputs
+    :members:
+    :show-inheritance:
+
 .. autoclass:: vllm.multimodal.MultiModalPlugin
     :members:
     :show-inheritance:
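
For readers skimming the diff, the ``multi_modal_data`` attribute described above is used
roughly as follows. This is a minimal sketch, assuming an ``LLM`` instance that was constructed
with the vision-specific arguments shown in ``examples/llava_example.py`` below; the prompt
format is model-specific.

.. code-block:: python

    from PIL import Image

    # ``llm`` is assumed to be an LLM instance created with the vision arguments
    # (model, image_token_id, image_input_shape, image_feature_size, ...).
    image = Image.open("images/stop_sign.jpg")

    outputs = llm.generate({
        # Prompt format follows the HuggingFace model card for the chosen model.
        "prompt": "USER: <image>\nWhat is the content of this image?\nASSISTANT:",
        # Dictionary following the vllm.multimodal.MultiModalDataDict schema.
        "multi_modal_data": {"image": image},
    })

    for o in outputs:
        print(o.outputs[0].text)
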
@@ -23,7 +23,6 @@ The following :ref:`engine arguments <engine_args>` are specific to VLMs:
 Currently, the support for vision language models on vLLM has the following limitations:

 * Only single image input is supported per text prompt.
-* Dynamic ``image_input_shape`` is not supported: the input image will be resized to the static ``image_input_shape``. This means our LLaVA-NeXT output may not exactly match the huggingface implementation.

 We are continuously improving user & developer experience for VLMs. Please `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_ if you have any feedback or feature requests.
@@ -42,12 +41,17 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM`` class
     )

+.. important::
+    Currently, you have to specify ``image_feature_size`` to support memory profiling.
+    To avoid OOM during runtime, you should set this to the maximum value supported by the model.
+    The calculation of feature size is specific to the model. For more details, please refer to
+    the function :code:`get_<model_name>_image_feature_size` inside the corresponding model file.
+
+    We will remove most of the vision-specific arguments in a future release as they can be inferred from the HuggingFace configuration.
+
 To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:

-* ``prompt``: The prompt should have a number of ``<image>`` tokens equal to ``image_feature_size``.
+* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
 * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.

 .. note::
@@ -57,8 +61,8 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptS

     .. code-block:: python

-        prompt = "<image>" * 576 + (
-            "\nUSER: What is the content of this image?\nASSISTANT:")
+        # Refer to the HuggingFace repo for the correct format to use
+        prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"

         # Load the image using PIL.Image
         image = ...
@@ -74,8 +78,6 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptS

 A code example can be found in `examples/llava_example.py <https://github.com/vllm-project/vllm/blob/main/examples/llava_example.py>`_.

-.. important::
-    We will remove the need to format image tokens in a future release. Afterwards, the input text will follow the same format as that for the original HuggingFace model.

 Online OpenAI Vision API Compatible Inference
 ----------------------------------------------
@@ -103,6 +105,11 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with
         --chat-template template_llava.jinja

+.. important::
+    Currently, you have to specify ``image_feature_size`` to support memory profiling.
+    To avoid OOM during runtime, you should set this to the maximum value supported by the model.
+    The calculation of feature size is specific to the model. For more details, please refer to
+    the function :code:`get_<model_name>_image_feature_size` inside the corresponding model file.
+
+    We will remove most of the vision-specific arguments in a future release as they can be inferred from the HuggingFace configuration.
+
 To consume the server, you can use the OpenAI client like in the example below:
@@ -121,6 +128,8 @@ To consume the server, you can use the OpenAI client like in the example below:
         messages=[{
             "role": "user",
             "content": [
+                # NOTE: The prompt formatting with the image token `<image>` is not needed
+                # since the prompt will be processed automatically by the API server.
                 {"type": "text", "text": "What's in this image?"},
                 {
                     "type": "image_url",
@@ -144,5 +153,4 @@ A full code example can be found in `examples/openai_vision_api_client.py <https
     export VLLM_IMAGE_FETCH_TIMEOUT=<timeout>

 .. note::
-    The prompt formatting with the image token ``<image>`` is not needed when serving VLMs with the API server since the prompt will be
-    processed automatically by the server.
+    There is no need to format the prompt in the API request since it will be handled by the server.
@@ -17,8 +17,7 @@ def run_llava():
         image_feature_size=576,
     )

-    prompt = "<image>" * 576 + (
-        "\nUSER: What is the content of this image?\nASSISTANT:")
+    prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"

     image = Image.open("images/stop_sign.jpg")
@@ -5,22 +5,17 @@ from PIL import Image

 from vllm import LLM, SamplingParams

-# Dynamic image input is currently not supported and therefore
-# a fixed image input shape and its corresponding feature size is required.
-# See https://github.com/vllm-project/vllm/pull/4199 for the complete
-# configuration matrix.
-

 def run_llava_next():
     llm = LLM(
         model="llava-hf/llava-v1.6-mistral-7b-hf",
         image_token_id=32000,
         image_input_shape="1,3,336,336",
-        image_feature_size=1176,
+        # Use the maximum possible value for memory profiling
+        image_feature_size=2928,
     )

-    prompt = "[INST] " + "<image>" * 1176 + (
-        "\nWhat is shown in this image? [/INST]")
+    prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
     url = "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg"
     image = Image.open(BytesIO(requests.get(url).content))
     sampling_params = SamplingParams(temperature=0.8,
@@ -5,6 +5,9 @@ from PIL import Image

 from vllm import LLM, SamplingParams

+# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
+# You can use `.buildkite/download-images.sh` to download them
+

 def run_phi3v():
     model_path = "microsoft/Phi-3-vision-128k-instruct"
@@ -18,7 +21,8 @@ def run_phi3v():
         trust_remote_code=True,
         image_token_id=32044,
         image_input_shape="1,3,1008,1344",
-        image_feature_size=1921,
+        # Use the maximum possible value for memory profiling
+        image_feature_size=2653,
         max_num_seqs=5,
     )
@@ -26,8 +30,6 @@ def run_phi3v():

     # single-image prompt
     prompt = "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n"  # noqa: E501
-    prompt = prompt.replace("<|image_1|>", "<|image|>" * 1921 + "<s>")
-
     sampling_params = SamplingParams(temperature=0, max_tokens=64)

     outputs = llm.generate(
@ -1,12 +1,13 @@
|
||||
import contextlib
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
from collections import UserList
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple,
|
||||
TypedDict, TypeVar)
|
||||
from typing import (Any, Dict, List, Literal, Optional, Tuple, TypedDict,
|
||||
TypeVar)
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -22,13 +23,10 @@ from vllm.distributed import (destroy_distributed_environment,
|
||||
destroy_model_parallel)
|
||||
from vllm.inputs import TextPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
from vllm.sequence import SampleLogprobs
|
||||
from vllm.utils import cuda_device_count_stateless, is_cpu
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# it will call torch.cuda.device_count()
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_TEST_DIR = os.path.dirname(__file__)
|
||||
@ -47,30 +45,42 @@ def _read_prompts(filename: str) -> List[str]:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ImageAsset:
|
||||
name: Literal["stop_sign", "cherry_blossom"]
|
||||
name: Literal["stop_sign", "cherry_blossom", "boardwalk"]
|
||||
|
||||
@cached_property
|
||||
def pil_image(self) -> Image.Image:
|
||||
if self.name == "boardwalk":
|
||||
return fetch_image(
|
||||
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
)
|
||||
|
||||
return Image.open(_IMAGE_DIR / f"{self.name}.jpg")
|
||||
|
||||
def for_hf(self) -> Image.Image:
|
||||
return self.pil_image
|
||||
|
||||
def for_vllm(self) -> Dict[str, Any]:
|
||||
return {"image": self.pil_image}
|
||||
|
||||
|
||||
class _ImageAssetPrompts(TypedDict):
|
||||
stop_sign: str
|
||||
cherry_blossom: str
|
||||
boardwalk: str
|
||||
|
||||
|
||||
class _ImageAssets(UserList):
|
||||
if sys.version_info < (3, 9):
|
||||
# UserList cannot be subscripted
|
||||
class _ImageAssetsBase(UserList):
|
||||
pass
|
||||
else:
|
||||
|
||||
class _ImageAssetsBase(UserList[ImageAsset]):
|
||||
pass
|
||||
|
||||
|
||||
class _ImageAssets(_ImageAssetsBase):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
[ImageAsset("stop_sign"),
|
||||
ImageAsset("cherry_blossom")])
|
||||
super().__init__([
|
||||
ImageAsset("stop_sign"),
|
||||
ImageAsset("cherry_blossom"),
|
||||
ImageAsset("boardwalk")
|
||||
])
|
||||
|
||||
def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
|
||||
"""
|
||||
@ -79,7 +89,10 @@ class _ImageAssets(UserList):
|
||||
The order of the returned prompts matches the order of the
|
||||
assets when iterating through this object.
|
||||
"""
|
||||
return [prompts["stop_sign"], prompts["cherry_blossom"]]
|
||||
return [
|
||||
prompts["stop_sign"], prompts["cherry_blossom"],
|
||||
prompts["boardwalk"]
|
||||
]
|
||||
|
||||
|
||||
IMAGE_ASSETS = _ImageAssets()
|
||||
@ -220,7 +233,7 @@ class HfRunner:
|
||||
self,
|
||||
prompts: List[str],
|
||||
images: Optional[List[Image.Image]] = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[List[List[int]], List[str]]]:
|
||||
if images:
|
||||
assert len(prompts) == len(images)
|
||||
@ -255,7 +268,7 @@ class HfRunner:
|
||||
prompts: List[str],
|
||||
max_tokens: int,
|
||||
images: Optional[List[Image.Image]] = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
outputs = self.generate(prompts,
|
||||
do_sample=False,
|
||||
@ -291,19 +304,30 @@ class HfRunner:
|
||||
self,
|
||||
prompts: List[str],
|
||||
max_tokens: int,
|
||||
images: Optional[List[Image.Image]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[List[torch.Tensor]]:
|
||||
all_logprobs = []
|
||||
for prompt in prompts:
|
||||
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
|
||||
all_logprobs: List[List[torch.Tensor]] = []
|
||||
for i, prompt in enumerate(prompts):
|
||||
processor_kwargs: Dict[str, Any] = {
|
||||
"text": prompt,
|
||||
"return_tensors": "pt",
|
||||
}
|
||||
if images is not None and images[i] is not None:
|
||||
processor_kwargs["images"] = images[i]
|
||||
|
||||
inputs = self.processor(**processor_kwargs)
|
||||
|
||||
output = self.model.generate(
|
||||
self.wrap_device(input_ids),
|
||||
**self.wrap_device(inputs),
|
||||
use_cache=True,
|
||||
do_sample=False,
|
||||
max_new_tokens=max_tokens,
|
||||
output_hidden_states=True,
|
||||
return_dict_in_generate=True,
|
||||
**kwargs,
|
||||
)
|
||||
seq_logprobs = []
|
||||
seq_logprobs: List[torch.Tensor] = []
|
||||
for hidden_states in output.hidden_states:
|
||||
last_hidden_states = hidden_states[-1][0]
|
||||
logits = torch.matmul(
|
||||
@ -323,20 +347,32 @@ class HfRunner:
|
||||
prompts: List[str],
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
images: Optional[List[Image.Image]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
|
||||
all_logprobs: List[List[Dict[int, float]]] = []
|
||||
all_output_ids: List[List[int]] = []
|
||||
all_output_strs: List[str] = []
|
||||
|
||||
for prompt in prompts:
|
||||
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
|
||||
for i, prompt in enumerate(prompts):
|
||||
processor_kwargs: Dict[str, Any] = {
|
||||
"text": prompt,
|
||||
"return_tensors": "pt",
|
||||
}
|
||||
if images is not None and images[i] is not None:
|
||||
processor_kwargs["images"] = images[i]
|
||||
|
||||
inputs = self.processor(**processor_kwargs)
|
||||
input_ids = inputs.input_ids
|
||||
|
||||
output = self.model.generate(
|
||||
self.wrap_device(input_ids),
|
||||
**self.wrap_device(inputs),
|
||||
use_cache=True,
|
||||
do_sample=False,
|
||||
max_new_tokens=max_tokens,
|
||||
output_hidden_states=True,
|
||||
return_dict_in_generate=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
seq_logprobs: List[torch.Tensor] = []
|
||||
@ -431,7 +467,7 @@ class VllmRunner:
|
||||
self,
|
||||
prompts: List[str],
|
||||
sampling_params: SamplingParams,
|
||||
images: Optional[List["MultiModalDataDict"]] = None,
|
||||
images: Optional[List[Image.Image]] = None,
|
||||
) -> List[Tuple[List[List[int]], List[str]]]:
|
||||
if images is not None:
|
||||
assert len(prompts) == len(images)
|
||||
@ -439,7 +475,7 @@ class VllmRunner:
|
||||
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
|
||||
if images is not None:
|
||||
for i, image in enumerate(images):
|
||||
inputs[i]["multi_modal_data"] = image
|
||||
inputs[i]["multi_modal_data"] = {"image": image}
|
||||
|
||||
req_outputs = self.model.generate(inputs,
|
||||
sampling_params=sampling_params)
|
||||
@ -462,10 +498,19 @@ class VllmRunner:
|
||||
self,
|
||||
prompts: List[str],
|
||||
sampling_params: SamplingParams,
|
||||
images: Optional[List[Image.Image]] = None,
|
||||
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
|
||||
assert sampling_params.logprobs is not None
|
||||
|
||||
req_outputs = self.model.generate(prompts,
|
||||
if images is not None:
|
||||
assert len(prompts) == len(images)
|
||||
|
||||
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
|
||||
if images is not None:
|
||||
for i, image in enumerate(images):
|
||||
inputs[i]["multi_modal_data"] = {"image": image}
|
||||
|
||||
req_outputs = self.model.generate(inputs,
|
||||
sampling_params=sampling_params)
|
||||
outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
|
||||
for req_output in req_outputs:
|
||||
@ -480,7 +525,7 @@ class VllmRunner:
|
||||
self,
|
||||
prompts: List[str],
|
||||
max_tokens: int,
|
||||
images: Optional[List["MultiModalDataDict"]] = None,
|
||||
images: Optional[List[Image.Image]] = None,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
||||
outputs = self.generate(prompts, greedy_params, images=images)
|
||||
@ -492,11 +537,14 @@ class VllmRunner:
|
||||
prompts: List[str],
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
images: Optional[List[Image.Image]] = None,
|
||||
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
|
||||
greedy_logprobs_params = SamplingParams(temperature=0.0,
|
||||
max_tokens=max_tokens,
|
||||
logprobs=num_logprobs)
|
||||
outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params)
|
||||
outputs = self.generate_w_logprobs(prompts,
|
||||
greedy_logprobs_params,
|
||||
images=images)
|
||||
|
||||
return [(output_ids, output_str, output_logprobs)
|
||||
for output_ids, output_str, output_logprobs in outputs]
|
||||
|
@ -30,9 +30,10 @@ else:
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [2])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(hf_runner, vllm_runner, image_assets,
|
||||
tensor_parallel_size: int, dtype: str,
|
||||
max_tokens: int) -> None:
|
||||
tensor_parallel_size: int, dtype: str, max_tokens: int,
|
||||
num_logprobs: int) -> None:
|
||||
if cuda_device_count_stateless() < tensor_parallel_size:
|
||||
pytest.skip(
|
||||
f"Need at least {tensor_parallel_size} GPUs to run the test.")
|
||||
@ -44,8 +45,10 @@ def test_models(hf_runner, vllm_runner, image_assets,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
model_and_config=model_and_vl_config[0],
|
||||
size_factors=[1.0],
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
)
|
||||
|
@ -4,18 +4,21 @@ import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.config import VisionLanguageConfig
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
|
||||
from .utils import check_outputs_equal
|
||||
from .utils import check_logprobs_close
|
||||
|
||||
pytestmark = pytest.mark.vlm
|
||||
|
||||
# The image token is placed before "user" on purpose so that the test can pass
|
||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
"stop_sign":
|
||||
"<image>\nUSER: What's the content of the image?\nASSISTANT:",
|
||||
"USER: <image>\nWhat's the content of the image?\nASSISTANT:",
|
||||
"cherry_blossom":
|
||||
"<image>\nUSER: What is the season?\nASSISTANT:",
|
||||
"USER: <image>\nWhat is the season?\nASSISTANT:",
|
||||
"boardwalk":
|
||||
"USER: <image>\nWhat's in this image?\nASSISTANT:",
|
||||
})
|
||||
|
||||
|
||||
@ -37,27 +40,34 @@ model_and_vl_config = [
|
||||
]
|
||||
|
||||
|
||||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
|
||||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||
Optional[SampleLogprobs]],
|
||||
vlm_config: VisionLanguageConfig, model_id: str):
|
||||
"""Sanitize vllm output to be comparable with hf output.
|
||||
The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
|
||||
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
|
||||
It also reduces `output_str` from "<image><image>bla" to "bla".
|
||||
"""
|
||||
output_ids, output_str = vllm_output
|
||||
output_ids, output_str, out_logprobs = vllm_output
|
||||
image_token_id = vlm_config.image_token_id
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
image_token_str = tokenizer.decode(image_token_id)
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
hf_output_ids = [
|
||||
token_id for idx, token_id in enumerate(output_ids)
|
||||
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
|
||||
]
|
||||
|
||||
hf_output_str = output_str \
|
||||
.replace(image_token_str * vlm_config.image_feature_size, "")
|
||||
assert hf_output_str[0] == " "
|
||||
hf_output_str = hf_output_str[1:]
|
||||
if hf_output_ids[-1] == eos_token_id:
|
||||
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
|
||||
|
||||
return hf_output_ids, hf_output_str
|
||||
return hf_output_ids, hf_output_str, out_logprobs
|
||||
|
||||
|
||||
def run_test(
|
||||
@ -66,8 +76,10 @@ def run_test(
|
||||
image_assets: _ImageAssets,
|
||||
model_and_config: Tuple[str, VisionLanguageConfig],
|
||||
*,
|
||||
size_factors: List[float],
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
@ -81,61 +93,85 @@ def run_test(
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
model_id, vlm_config = model_and_config
|
||||
hf_images = [asset.for_hf() for asset in image_assets]
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_image = [(
|
||||
[prompt for _ in size_factors],
|
||||
[rescale_image_size(image, factor) for factor in size_factors],
|
||||
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||
|
||||
# NOTE: take care of the order. run vLLM first, and then run HF.
|
||||
# vLLM needs a fresh new process without cuda initialization.
|
||||
# if we run HF first, the cuda initialization will be done and it
|
||||
# will hurt multiprocessing backend with fork method (the default method).
|
||||
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(model_id,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True,
|
||||
**vlm_config.as_cli_args_dict()) as vllm_model:
|
||||
|
||||
# NOTE: `asset.for_vllm` will call `torch.cuda.device_count()`
|
||||
# we must put it inside the vllm_runner context manager
|
||||
# i.e. after creating vLLM instance.
|
||||
vllm_images = [asset.for_vllm() for asset in image_assets]
|
||||
|
||||
vllm_image_prompts = [
|
||||
p.replace("<image>", "<image>" * vlm_config.image_feature_size)
|
||||
for p in HF_IMAGE_PROMPTS
|
||||
vllm_outputs_per_image = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images)
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
|
||||
max_tokens,
|
||||
images=vllm_images)
|
||||
|
||||
with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
|
||||
max_tokens,
|
||||
images=hf_images)
|
||||
hf_outputs_per_image = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images)
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
check_outputs_equal(
|
||||
hf_outputs,
|
||||
[
|
||||
vllm_to_hf_output(vllm_output, vlm_config, model_id)
|
||||
for vllm_output in vllm_outputs
|
||||
],
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
|
||||
vllm_outputs_per_image):
|
||||
# TODO: Check whether using original CLIPVisionModel can improve
|
||||
# consistency against HF
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=[
|
||||
vllm_to_hf_output(vllm_output, vlm_config, model_id)
|
||||
for vllm_output in vllm_outputs
|
||||
],
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
[
|
||||
# No image
|
||||
[],
|
||||
# Single-scale
|
||||
[1.0],
|
||||
# Single-scale, batched
|
||||
[1.0, 1.0, 1.0],
|
||||
# Multi-scale
|
||||
[0.25, 0.5, 1.0],
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
dtype: str, max_tokens: int) -> None:
|
||||
size_factors, dtype: str, max_tokens: int,
|
||||
num_logprobs: int) -> None:
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
model_and_config,
|
||||
size_factors=size_factors,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
@ -1,12 +1,15 @@
|
||||
from typing import List, Tuple
|
||||
import re
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.config import VisionLanguageConfig
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
from ..conftest import IMAGE_ASSETS
|
||||
from .utils import check_outputs_equal
|
||||
from .utils import check_logprobs_close
|
||||
|
||||
pytestmark = pytest.mark.vlm
|
||||
|
||||
@ -15,21 +18,20 @@ _PREFACE = (
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's "
|
||||
"questions.")
|
||||
|
||||
# The image token is placed before "user" on purpose so that the test can pass
|
||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
"stop_sign":
|
||||
f"{_PREFACE} <image>\nUSER: What's the content of the image?\nASSISTANT:",
|
||||
f"{_PREFACE} USER: <image>\nWhat's the content of the image? ASSISTANT:",
|
||||
"cherry_blossom":
|
||||
f"{_PREFACE} <image>\nUSER: What is the season?\nASSISTANT:",
|
||||
f"{_PREFACE} USER: <image>\nWhat is the season? ASSISTANT:",
|
||||
"boardwalk":
|
||||
f"{_PREFACE} USER: <image>\nWhat's in this image? ASSISTANT:",
|
||||
})
|
||||
|
||||
|
||||
def iter_llava_next_configs(model_name: str):
|
||||
# Need to use the max possible feature size for profile_run
|
||||
image_hw_to_feature_size = {
|
||||
(336, 336): 1176,
|
||||
(672, 672): 2928,
|
||||
(1344, 336): 1944,
|
||||
(336, 1344): 1890,
|
||||
(336, 336): 2928,
|
||||
}
|
||||
|
||||
for (h, w), f in image_hw_to_feature_size.items():
|
||||
@ -47,37 +49,55 @@ model_and_vl_config = [
|
||||
]
|
||||
|
||||
|
||||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
|
||||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||
Optional[SampleLogprobs]],
|
||||
vlm_config: VisionLanguageConfig, model_id: str):
|
||||
"""Sanitize vllm output to be comparable with hf output.
|
||||
The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
|
||||
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
|
||||
It also reduces `output_str` from "<image><image>bla" to "bla".
|
||||
"""
|
||||
output_ids, output_str = vllm_output
|
||||
output_ids, output_str, out_logprobs = vllm_output
|
||||
image_token_id = vlm_config.image_token_id
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
image_token_str = tokenizer.decode(image_token_id)
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
hf_output_ids = [
|
||||
token_id for idx, token_id in enumerate(output_ids)
|
||||
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
|
||||
]
|
||||
hf_output_str = output_str \
|
||||
.replace(image_token_str * vlm_config.image_feature_size, " ")
|
||||
|
||||
return hf_output_ids, hf_output_str
|
||||
hf_output_str = re.sub(fr"({image_token_str})+", "", output_str)
|
||||
assert hf_output_str[0] == " "
|
||||
hf_output_str = hf_output_str[1:]
|
||||
if hf_output_ids[-1] == eos_token_id:
|
||||
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
|
||||
|
||||
return hf_output_ids, hf_output_str, out_logprobs
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Inconsistent image processor being used due to lack "
|
||||
"of support for dynamic image token replacement")
|
||||
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
[
|
||||
# No image
|
||||
[],
|
||||
# Single-scale
|
||||
[1.0],
|
||||
# Single-scale, batched
|
||||
[1.0, 1.0, 1.0],
|
||||
# Multi-scale
|
||||
[0.25, 0.5, 1.0],
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
dtype: str, max_tokens: int) -> None:
|
||||
size_factors, dtype: str, max_tokens: int,
|
||||
num_logprobs: int) -> None:
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
All the image fixtures for the test is under tests/images.
|
||||
@ -88,37 +108,46 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
model_id, vlm_config = model_and_config
|
||||
hf_images = [asset.for_hf() for asset in image_assets]
|
||||
vllm_images = [asset.for_vllm() for asset in image_assets]
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_image = [(
|
||||
[prompt for _ in size_factors],
|
||||
[rescale_image_size(image, factor) for factor in size_factors],
|
||||
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(model_id,
|
||||
dtype=dtype,
|
||||
max_model_len=4096,
|
||||
enforce_eager=True,
|
||||
**vlm_config.as_cli_args_dict()) as vllm_model:
|
||||
vllm_outputs_per_image = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images)
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
|
||||
max_tokens,
|
||||
images=hf_images)
|
||||
hf_outputs_per_image = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images)
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
vllm_image_prompts = [
|
||||
p.replace("<image>", "<image>" * vlm_config.image_feature_size)
|
||||
for p in HF_IMAGE_PROMPTS
|
||||
]
|
||||
|
||||
with vllm_runner(
|
||||
model_id,
|
||||
dtype=dtype,
|
||||
# should be greater than image_feature_size
|
||||
max_model_len=4096,
|
||||
enforce_eager=True,
|
||||
**vlm_config.as_cli_args_dict(),
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
|
||||
max_tokens,
|
||||
images=vllm_images)
|
||||
|
||||
check_outputs_equal(
|
||||
hf_outputs,
|
||||
[
|
||||
vllm_to_hf_output(vllm_output, vlm_config, model_id)
|
||||
for vllm_output in vllm_outputs
|
||||
],
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
|
||||
vllm_outputs_per_image):
|
||||
# TODO: Check whether using original CLIPVisionModel can improve
|
||||
# consistency against HF
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=[
|
||||
vllm_to_hf_output(vllm_output, vlm_config, model_id)
|
||||
for vllm_output in vllm_outputs
|
||||
],
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
@ -1,29 +1,33 @@
|
||||
import re
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.config import VisionLanguageConfig
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
from vllm.utils import is_cpu
|
||||
|
||||
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
|
||||
from .utils import check_outputs_equal
|
||||
from .utils import check_logprobs_close
|
||||
|
||||
pytestmark = pytest.mark.vlm
|
||||
|
||||
# The image token is placed before "user" on purpose so that the test can pass
|
||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
"stop_sign":
|
||||
"<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501
|
||||
"cherry_blossom":
|
||||
"<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n", # noqa: E501
|
||||
"<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
|
||||
"boardwalk":
|
||||
"<|user|>\n<|image_1|>\nWhat's in this image?<|end|>\n<|assistant|>\n",
|
||||
})
|
||||
|
||||
|
||||
def iter_phi3v_configs(model_name: str):
|
||||
# Need to use the max possible feature size for profile_run
|
||||
image_hw_to_feature_size = {
|
||||
(1008, 1344): 1921,
|
||||
(2016, 2688): 1933,
|
||||
(1008, 1344): 2653,
|
||||
}
|
||||
|
||||
for (h, w), f in image_hw_to_feature_size.items():
|
||||
@ -39,29 +43,29 @@ model_and_vl_config = [
|
||||
]
|
||||
|
||||
|
||||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
|
||||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||
Optional[SampleLogprobs]],
|
||||
vlm_config: VisionLanguageConfig, model_id: str):
|
||||
"""Sanitize vllm output to be comparable with hf output.
|
||||
The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
|
||||
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
|
||||
It also reduces `output_str` from "<image><image>bla" to "bla".
|
||||
"""
|
||||
output_ids, output_str = vllm_output
|
||||
image_token_id = vlm_config.image_token_id
|
||||
output_ids, output_str, out_logprobs = vllm_output
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
image_token_str = tokenizer.decode(image_token_id)
|
||||
output_str_without_image = re.sub(r"(<\|image_\d+\|>)+", "", output_str)
|
||||
assert output_str_without_image[0] == " "
|
||||
output_str_without_image = output_str_without_image[1:]
|
||||
|
||||
hf_output_ids = [
|
||||
token_id if token_id != image_token_id else 0
|
||||
for idx, token_id in enumerate(output_ids)
|
||||
]
|
||||
hf_output_str = output_str \
|
||||
.replace(image_token_str * vlm_config.image_feature_size, "") \
|
||||
.replace("<s>", " ").replace("<|user|>", "") \
|
||||
hf_output_str = output_str_without_image.replace("<|user|>", "") \
|
||||
.replace("<|end|>\n<|assistant|>", " ")
|
||||
|
||||
return hf_output_ids, hf_output_str
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
hf_output_ids = tokenizer.encode(output_str_without_image)
|
||||
assert hf_output_ids[0] == 1
|
||||
hf_output_ids = hf_output_ids[1:]
|
||||
|
||||
return hf_output_ids, hf_output_str, out_logprobs
|
||||
|
||||
|
||||
target_dtype = "half"
|
||||
@ -75,8 +79,10 @@ def run_test(
|
||||
image_assets: _ImageAssets,
|
||||
model_and_config: Tuple[str, VisionLanguageConfig],
|
||||
*,
|
||||
size_factors: List[float],
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
@ -90,73 +96,91 @@ def run_test(
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
model_id, vlm_config = model_and_config
|
||||
hf_images = [asset.for_hf() for asset in image_assets]
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_image = [(
|
||||
[prompt for _ in size_factors],
|
||||
[rescale_image_size(image, factor) for factor in size_factors],
|
||||
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||
|
||||
# NOTE: take care of the order. run vLLM first, and then run HF.
|
||||
# vLLM needs a fresh new process without cuda initialization.
|
||||
# if we run HF first, the cuda initialization will be done and it
|
||||
# will hurt multiprocessing backend with fork method (the default method).
|
||||
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(model_id,
|
||||
max_model_len=2048,
|
||||
max_model_len=4096,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
enforce_eager=True,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True,
|
||||
**vlm_config.as_cli_args_dict()) as vllm_model:
|
||||
# NOTE: `asset.for_vllm` will call `torch.cuda.device_count()`
|
||||
# we must put it inside the vllm_runner context manager
|
||||
# i.e. after creating vLLM instance.
|
||||
|
||||
vllm_images = [asset.for_vllm() for asset in image_assets]
|
||||
|
||||
vllm_image_prompts = [
|
||||
p.replace("<|image_1|>",
|
||||
"<|image|>" * vlm_config.image_feature_size + "<s>")
|
||||
for p in HF_IMAGE_PROMPTS
|
||||
vllm_outputs_per_image = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=vllm_images)
|
||||
for prompts, vllm_images in inputs_per_image
|
||||
]
|
||||
|
||||
vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
|
||||
max_tokens,
|
||||
images=vllm_images)
|
||||
|
||||
# use eager mode for hf runner, since phi3_v didn't work with flash_attn
|
||||
hf_model_kwargs = {"_attn_implementation": "eager"}
|
||||
with hf_runner(model_id, dtype=dtype,
|
||||
model_kwargs=hf_model_kwargs) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(
|
||||
HF_IMAGE_PROMPTS,
|
||||
max_tokens,
|
||||
images=hf_images,
|
||||
eos_token_id=hf_model.processor.tokenizer.eos_token_id)
|
||||
eos_token_id = hf_model.processor.tokenizer.eos_token_id
|
||||
hf_outputs_per_image = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=hf_images,
|
||||
eos_token_id=eos_token_id)
|
||||
for prompts, hf_images in inputs_per_image
|
||||
]
|
||||
|
||||
check_outputs_equal(
|
||||
hf_outputs,
|
||||
[
|
||||
vllm_to_hf_output(vllm_output, vlm_config, model_id)
|
||||
for vllm_output in vllm_outputs
|
||||
],
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
|
||||
vllm_outputs_per_image):
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=[
|
||||
vllm_to_hf_output(vllm_output, vlm_config, model_id)
|
||||
for vllm_output in vllm_outputs
|
||||
],
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
# Since we use _attn_implementation="eager" for hf_runner, here is
|
||||
# numeric difference for longer context and test can't pass
|
||||
@pytest.mark.xfail(
|
||||
reason="Inconsistent image processor being used due to lack "
|
||||
"of support for dynamic image token replacement")
|
||||
# Since we use _attn_implementation="eager" for hf_runner, there is more
|
||||
# significant numerical difference. The basic `logprobs=5` fails to pass.
|
||||
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
[
|
||||
# No image
|
||||
[],
|
||||
# Single-scale
|
||||
[1.0],
|
||||
# Single-scale, batched
|
||||
[1.0, 1.0, 1.0],
|
||||
# Multi-scale
|
||||
[0.25, 0.5, 1.0],
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [target_dtype])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [10])
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
dtype: str, max_tokens: int) -> None:
|
||||
size_factors, dtype: str, max_tokens: int,
|
||||
num_logprobs: int) -> None:
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
model_and_config,
|
||||
size_factors=size_factors,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
@ -1,11 +1,18 @@
|
||||
from typing import Dict, List, Tuple
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
TokensText = Tuple[List[int], str]
|
||||
|
||||
|
||||
def check_outputs_equal(outputs_0_lst: List[TokensText],
|
||||
outputs_1_lst: List[TokensText], name_0: str,
|
||||
name_1: str):
|
||||
def check_outputs_equal(
|
||||
*,
|
||||
outputs_0_lst: Sequence[TokensText],
|
||||
outputs_1_lst: Sequence[TokensText],
|
||||
name_0: str,
|
||||
name_1: str,
|
||||
):
|
||||
"""
|
||||
Compare the two sequences generated by different models,
|
||||
which should be equal.
|
||||
@ -18,20 +25,28 @@ def check_outputs_equal(outputs_0_lst: List[TokensText],
|
||||
output_ids_0, output_str_0 = outputs_0
|
||||
output_ids_1, output_str_1 = outputs_1
|
||||
|
||||
assert output_str_0 == output_str_1, (f"Test{prompt_idx}:"
|
||||
f"\n{name_0}:\t{output_str_0!r}"
|
||||
f"\n{name_1}:\t{output_str_1!r}")
|
||||
assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:"
|
||||
f"\n{name_0}:\t{output_str_0!r}"
|
||||
f"\n{name_1}:\t{output_str_1!r}")
|
||||
# The text and token outputs should exactly match
|
||||
fail_msg = (f"Test{prompt_idx}:"
|
||||
f"\n{name_0}:\t{output_str_0!r}"
|
||||
f"\n{name_1}:\t{output_str_1!r}")
|
||||
|
||||
assert output_str_0 == output_str_1, fail_msg
|
||||
assert output_ids_0 == output_ids_1, fail_msg
|
||||
|
||||
|
||||
TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]
|
||||
TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int,
|
||||
float]],
|
||||
SampleLogprobs]]]
|
||||
|
||||
|
||||
def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
|
||||
outputs_1_lst: List[TokensTextLogprobs], name_0: str,
|
||||
name_1: str):
|
||||
def check_logprobs_close(
|
||||
*,
|
||||
outputs_0_lst: Sequence[TokensTextLogprobs],
|
||||
outputs_1_lst: Sequence[TokensTextLogprobs],
|
||||
name_0: str,
|
||||
name_1: str,
|
||||
warn_on_mismatch: bool = True,
|
||||
):
|
||||
"""
|
||||
Compare the logprobs of two sequences generated by different models,
|
||||
which should be similar but not necessarily equal.
|
||||
@ -45,21 +60,52 @@ def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
|
||||
output_ids_0, output_str_0, logprobs_0 = outputs_0
|
||||
output_ids_1, output_str_1, logprobs_1 = outputs_1
|
||||
|
||||
if logprobs_0 is None:
|
||||
logprobs_0 = [None] * len(output_ids_0)
|
||||
if logprobs_1 is None:
|
||||
logprobs_1 = [None] * len(output_ids_1)
|
||||
|
||||
# Loop through generated tokens.
|
||||
for idx, (output_id_0,
|
||||
output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
|
||||
|
||||
# If generated tokens don't match, then
|
||||
if output_id_0 != output_id_1:
|
||||
logprobs_elem_0 = logprobs_0[idx]
|
||||
logprobs_elem_1 = logprobs_1[idx]
|
||||
|
||||
# Each predicted token must be in top N logprobs of the other
|
||||
assert output_id_0 in logprobs_1[idx], (
|
||||
fail_msg = (
|
||||
f"Test{prompt_idx}:"
|
||||
f"\n{name_0}:\t{output_str_0!r}"
|
||||
f"\n{name_1}:\t{output_str_1!r}")
|
||||
assert output_id_1 in logprobs_0[idx], (
|
||||
f"Test{prompt_idx}:"
|
||||
f"\n{name_0}:\t{output_str_0!r}"
|
||||
f"\n{name_1}:\t{output_str_1!r}")
|
||||
f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}"
|
||||
f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}")
|
||||
|
||||
assert logprobs_elem_0 is not None, fail_msg
|
||||
assert logprobs_elem_1 is not None, fail_msg
|
||||
assert output_id_0 in logprobs_elem_1, fail_msg
|
||||
assert output_id_1 in logprobs_elem_0, fail_msg
|
||||
|
||||
if warn_on_mismatch:
|
||||
with warnings.catch_warnings():
|
||||
# This ensures that repeated warnings are shown
|
||||
# in the output, not just the first occurrence
|
||||
warnings.simplefilter("always")
|
||||
|
||||
warnings.warn(fail_msg, stacklevel=2)
|
||||
|
||||
# Break out since sequences will now diverge.
|
||||
break
|
||||
else:
|
||||
if output_str_0 != output_str_1 and warn_on_mismatch:
|
||||
# The token outputs exactly match,
|
||||
# so the text outputs should exactly match as well
|
||||
fail_msg = (f"Test{prompt_idx}:"
|
||||
f"\n{name_0}:\t{output_str_0!r}"
|
||||
f"\n{name_1}:\t{output_str_1!r}")
|
||||
|
||||
with warnings.catch_warnings():
|
||||
# This ensures that repeated warnings are shown
|
||||
# in the output, not just the first occurrence
|
||||
warnings.simplefilter("always")
|
||||
|
||||
warnings.warn(fail_msg, stacklevel=2)
|
||||
|
@ -4,12 +4,12 @@ from transformers import CLIPImageProcessor, LlavaNextImageProcessor
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from ..conftest import _STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["half", "float"])
|
||||
def test_clip_image_processor(image_assets, dtype):
|
||||
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
|
||||
def test_clip_image_processor(image_assets, dtype, size_factor):
|
||||
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
|
||||
|
||||
hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
|
||||
@ -26,13 +26,15 @@ def test_clip_image_processor(image_assets, dtype):
|
||||
)
|
||||
|
||||
for asset in image_assets:
|
||||
image = rescale_image_size(asset.pil_image, size_factor)
|
||||
|
||||
hf_result = hf_processor.preprocess(
|
||||
asset.pil_image,
|
||||
image,
|
||||
return_tensors="pt",
|
||||
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
|
||||
)
|
||||
vllm_result = MULTIMODAL_REGISTRY.map_input(
|
||||
model_config,
|
||||
{"image": asset.pil_image},
|
||||
{"image": image},
|
||||
)
|
||||
|
||||
assert hf_result.keys() == vllm_result.keys()
|
||||
@ -44,12 +46,10 @@ def test_clip_image_processor(image_assets, dtype):
|
||||
assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Inconsistent image processor being used due to lack "
|
||||
"of support for dynamic image token replacement")
|
||||
@pytest.mark.parametrize("dtype", ["half", "float"])
|
||||
def test_llava_next_image_processor(image_assets, dtype):
|
||||
MODEL_NAME = "llava-hf/llava-v1.6-34b-hf"
|
||||
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
|
||||
def test_llava_next_image_processor(image_assets, dtype, size_factor):
|
||||
MODEL_NAME = "llava-hf/llava-v1.6-vicuna-7b-hf"
|
||||
|
||||
hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME)
|
||||
assert isinstance(hf_processor, LlavaNextImageProcessor)
|
||||
@ -65,13 +65,15 @@ def test_llava_next_image_processor(image_assets, dtype):
|
||||
)
|
||||
|
||||
for asset in image_assets:
|
||||
image = rescale_image_size(asset.pil_image, size_factor)
|
||||
|
||||
hf_result = hf_processor.preprocess(
|
||||
asset.pil_image,
|
||||
image,
|
||||
return_tensors="pt",
|
||||
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
|
||||
)
|
||||
vllm_result = MULTIMODAL_REGISTRY.map_input(
|
||||
model_config,
|
||||
{"image": asset.pil_image},
|
||||
{"image": image},
|
||||
)
|
||||
|
||||
assert hf_result.keys() == vllm_result.keys()
|
||||
@ -81,36 +83,3 @@ def test_llava_next_image_processor(image_assets, dtype):
|
||||
|
||||
assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}"
|
||||
assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Example image pixels were not processed using HuggingFace")
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
def test_image_pixel_types(image_assets, dtype):
|
||||
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
|
||||
|
||||
model_config = ModelConfig(
|
||||
model=MODEL_NAME,
|
||||
tokenizer=MODEL_NAME,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype=dtype,
|
||||
revision=None,
|
||||
)
|
||||
for asset in image_assets:
|
||||
image_result = MULTIMODAL_REGISTRY.map_input(
|
||||
model_config,
|
||||
{"image": asset.pil_image},
|
||||
)
|
||||
tensor_result = MULTIMODAL_REGISTRY.map_input(
|
||||
model_config,
|
||||
{"image": asset.pil_image},
|
||||
)
|
||||
|
||||
assert image_result.keys() == tensor_result.keys()
|
||||
for key, image_arr in image_result.items():
|
||||
tensor_arr: np.ndarray = tensor_result[key].numpy()
|
||||
|
||||
assert image_arr.shape == tensor_arr.shape, f"Failed for key={key}"
|
||||
assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}"
|
||||
|
@ -5,10 +5,9 @@ from typing import Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from PIL import Image
|
||||
|
||||
from vllm.multimodal.utils import ImageFetchAiohttp
|
||||
from vllm.multimodal.utils import ImageFetchAiohttp, fetch_image
|
||||
|
||||
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
|
||||
TEST_IMAGE_URLS = [
|
||||
@ -19,12 +18,9 @@ TEST_IMAGE_URLS = [
|
||||
]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def url_images() -> Dict[str, Image.Image]:
|
||||
return {
|
||||
image_url: await ImageFetchAiohttp.fetch_image(image_url)
|
||||
for image_url in TEST_IMAGE_URLS
|
||||
}
|
||||
@pytest.fixture(scope="module")
|
||||
def url_images() -> Dict[str, Image.Image]:
|
||||
return {image_url: fetch_image(image_url) for image_url in TEST_IMAGE_URLS}
|
||||
|
||||
|
||||
def get_supported_suffixes() -> Tuple[str, ...]:
|
||||
@ -41,7 +37,15 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
|
||||
return (np.asarray(a) == np.asarray(b.convert(a.mode))).all()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
|
||||
async def test_fetch_image_http(image_url: str):
|
||||
image_sync = fetch_image(image_url)
|
||||
image_async = await ImageFetchAiohttp.fetch_image(image_url)
|
||||
assert _image_equals(image_sync, image_async)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
|
||||
@pytest.mark.parametrize("suffix", get_supported_suffixes())
|
||||
async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
|
||||
@ -68,8 +72,11 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
|
||||
base64_image = base64.b64encode(f.read()).decode("utf-8")
|
||||
data_url = f"data:{mime_type};base64,{base64_image}"
|
||||
|
||||
data_image = await ImageFetchAiohttp.fetch_image(data_url)
|
||||
data_image_sync = fetch_image(data_url)
|
||||
if _image_equals(url_image, Image.open(f)):
|
||||
assert _image_equals(url_image, data_image)
|
||||
assert _image_equals(url_image, data_image_sync)
|
||||
else:
|
||||
pass # Lossy format; only check that image can be opened
|
||||
|
||||
data_image_async = await ImageFetchAiohttp.fetch_image(data_url)
|
||||
assert _image_equals(data_image_sync, data_image_async)
|
||||
|
@ -5,7 +5,7 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple,
|
||||
Union)
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig, PreTrainedTokenizerBase
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
@ -1303,16 +1303,6 @@ class VisionLanguageConfig:
|
||||
image_input_shape: tuple
|
||||
image_feature_size: int
|
||||
|
||||
#TODO(ywang96): make this a cached property once we refactor the
|
||||
# VisionLanguageConfig class.
|
||||
def get_image_token_text(
|
||||
self, tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]:
|
||||
"""Get the image token placeholder text to be inserted into the
|
||||
text prompt and the string representation of the image token id.
|
||||
"""
|
||||
image_token_str = tokenizer.decode(self.image_token_id)
|
||||
return image_token_str * self.image_feature_size, image_token_str
|
||||
|
||||
def as_cli_args_dict(self) -> Dict[str, Any]:
|
||||
"""Flatten vision language config to pure args.
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import codecs
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable,
|
||||
List, Optional)
|
||||
from typing import Sequence as GenericSequence
|
||||
@ -10,7 +11,7 @@ from fastapi import Request
|
||||
from openai.types.chat import (ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartTextParam)
|
||||
|
||||
from vllm.config import ModelConfig, VisionLanguageConfig
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionContentPartParam, ChatCompletionLogProb,
|
||||
@ -27,8 +28,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.guided_decoding import (
|
||||
get_guided_decoding_logits_processor)
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.multimodal.utils import (async_get_and_parse_image,
|
||||
get_full_image_text_prompt)
|
||||
from vllm.multimodal.utils import async_get_and_parse_image
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
@ -97,6 +97,36 @@ class OpenAIServingChat(OpenAIServing):
|
||||
logger.warning(
|
||||
"No chat template provided. Chat API will not work.")
|
||||
|
||||
@cached_property
|
||||
def image_token_str(self) -> Optional[str]:
|
||||
# TODO: Let user specify how to insert image tokens into prompt
|
||||
# (similar to chat template)
|
||||
model_type = self.model_config.hf_config.model_type
|
||||
if model_type == "phi3_v":
|
||||
# Workaround since this token is not defined in the tokenizer
|
||||
return "<|image_1|>"
|
||||
if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv",
|
||||
"paligemma"):
|
||||
# These models do not use image tokens in the prompt
|
||||
return None
|
||||
|
||||
# The default behaviour assumes that the image token is
|
||||
# available to the tokenizer.
|
||||
# (Suitable for LLaVA, Idefics2, DeepSeek-VL)
|
||||
vlm_config = self.model_config.multimodal_config
|
||||
if vlm_config is None:
|
||||
raise ValueError(
|
||||
"'image_url' input is not supported as the loaded "
|
||||
"model is not multimodal.")
|
||||
|
||||
image_token_id = vlm_config.image_token_id
|
||||
if vlm_config.image_token_id is None:
|
||||
raise ValueError(
|
||||
"'image_url' input is not supported as the loaded "
|
||||
"model does not specify an image token.")
|
||||
|
||||
return self.tokenizer.decode(image_token_id)
|
||||
|
||||
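# NOTE: Illustrative sketch, not part of this commit. The default branch of
# `image_token_str` above simply decodes the configured image token id back
# into its placeholder string. A standalone demonstration, assuming a
# LLaVA-style checkpoint whose image token id is 32000 (both the model name
# and the id are assumptions used only for illustration):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")
print(tokenizer.decode(32000))  # expected to print the placeholder, e.g. "<image>"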
def _parse_chat_message_content_parts(
|
||||
self,
|
||||
role: str,
|
||||
@ -105,21 +135,26 @@ class OpenAIServingChat(OpenAIServing):
|
||||
texts: List[str] = []
|
||||
mm_futures: List[Awaitable[MultiModalDataDict]] = []
|
||||
|
||||
vlm_config: Optional[VisionLanguageConfig] = getattr(
|
||||
self.engine.engine, "vision_language_config", None)
|
||||
model_config = getattr(self.engine.engine, "model_config", None)
|
||||
|
||||
for part in parts:
|
||||
part_type = part["type"]
|
||||
if part_type == "text":
|
||||
text = cast(ChatCompletionContentPartTextParam, part)["text"]
|
||||
texts.append(text)
|
||||
elif part_type == "image_url":
|
||||
if vlm_config is None:
|
||||
raise ValueError(
|
||||
"'image_url' input is not supported as the loaded "
|
||||
"model is not multimodal.")
|
||||
assert self.tokenizer is not None
|
||||
if len(mm_futures) > 0:
|
||||
raise NotImplementedError(
|
||||
"Multiple 'image_url' input is currently not supported."
|
||||
)
|
||||
|
||||
image_token_str = self.image_token_str
|
||||
if image_token_str is not None:
|
||||
if any(image_token_str in text for text in texts):
|
||||
logger.warning(
|
||||
"Detected image token string in the text prompt. "
|
||||
"Skipping prompt formatting.")
|
||||
else:
|
||||
texts.append(image_token_str)
|
||||
|
||||
image_url = cast(ChatCompletionContentPartImageParam,
|
||||
part)["image_url"]
|
||||
|
||||
@ -128,43 +163,13 @@ class OpenAIServingChat(OpenAIServing):
|
||||
"'image_url.detail' is currently not supported and "
|
||||
"will be ignored.")
|
||||
|
||||
mm_future = async_get_and_parse_image(image_url["url"])
|
||||
mm_futures.append(mm_future)
|
||||
|
||||
image_future = async_get_and_parse_image(image_url["url"])
|
||||
mm_futures.append(image_future)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown part type: {part_type}")
|
||||
|
||||
text_prompt = "\n".join(texts)
|
||||
|
||||
if vlm_config is not None and len(mm_futures):
|
||||
|
||||
assert len(
|
||||
mm_futures
|
||||
) == 1, "Multiple 'image_url' input is currently not supported."
|
||||
(image_token_prompt,
|
||||
image_token_str) = vlm_config.get_image_token_text(self.tokenizer)
|
||||
|
||||
# NOTE: If image token string (e.g, <image>) is already present
|
||||
# in the text prompt, we assume it follows the same format required
|
||||
# by the engine.
|
||||
if image_token_str in text_prompt:
|
||||
logger.warning(
|
||||
"Detected image token string in the text prompt. "
|
||||
"Skipping prompt formatting.")
|
||||
messages = [
|
||||
ConversationMessage(role=role, content=text_prompt)
|
||||
]
|
||||
|
||||
else:
|
||||
full_prompt = get_full_image_text_prompt(
|
||||
image_prompt=image_token_prompt,
|
||||
text_prompt=text_prompt,
|
||||
config=model_config)
|
||||
messages = [
|
||||
ConversationMessage(role=role, content=full_prompt)
|
||||
]
|
||||
else:
|
||||
messages = [ConversationMessage(role=role, content=text_prompt)]
|
||||
messages = [ConversationMessage(role=role, content=text_prompt)]
|
||||
|
||||
return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
|
||||
|
||||
@ -267,7 +272,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
"prompt": prompt_text,
|
||||
"prompt_token_ids": prompt_ids,
|
||||
}
|
||||
if mm_data is not None:
|
||||
if mm_data:
|
||||
inputs["multi_modal_data"] = mm_data
|
||||
|
||||
is_tracing_enabled = await self.engine.is_tracing_enabled()
|
||||
|
@ -36,6 +36,7 @@ class OpenAIServing:
|
||||
super().__init__()
|
||||
|
||||
self.engine = engine
|
||||
self.model_config = model_config
|
||||
self.max_model_len = model_config.max_model_len
|
||||
|
||||
# A separate tokenizer to map token IDs to strings.
|
||||
|
@ -140,7 +140,8 @@ class InputRegistry:
|
||||
|
||||
The model is identified by ``model_config``.
|
||||
|
||||
TODO: Add guide [ref: PR #5276]
|
||||
See also:
|
||||
:ref:`adding_a_new_multimodal_model`
|
||||
"""
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.model_loader import get_model_architecture
|
||||
|
@ -8,10 +8,14 @@ from PIL import Image
|
||||
from transformers import CLIPVisionConfig
|
||||
from transformers.models.clip.modeling_clip import CLIPAttention
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs import LLMInputs
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.multimodal.image import (cached_get_tokenizer,
|
||||
repeat_and_pad_image_tokens)
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
|
||||
@ -64,6 +68,39 @@ def dummy_image_for_clip(
|
||||
return {"image": image}
|
||||
|
||||
|
||||
def input_processor_for_clip(
|
||||
model_config: ModelConfig,
|
||||
hf_config: CLIPVisionConfig,
|
||||
llm_inputs: LLMInputs,
|
||||
*,
|
||||
image_token_id: int,
|
||||
image_feature_size_override: Optional[int] = None,
|
||||
):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||
|
||||
if image_feature_size_override is None:
|
||||
image_feature_size = get_clip_image_feature_size(hf_config)
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
|
||||
tokenizer,
|
||||
llm_inputs.get("prompt"),
|
||||
llm_inputs["prompt_token_ids"],
|
||||
image_token_id=image_token_id,
|
||||
repeat_count=image_feature_size,
|
||||
)
|
||||
|
||||
# NOTE: Create a defensive copy of the original inputs
|
||||
return LLMInputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
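# NOTE: Illustrative sketch, not part of this commit. `input_processor_for_clip`
# rewrites the prompt so that the single image placeholder token is repeated
# once per visual feature. For a 336px CLIP ViT-L/14 tower that is roughly
# (336 // 14) ** 2 = 576 repetitions (the exact count depends on whether the
# class token is kept; all ids below are toy values):
image_token_id = 32000                      # hypothetical placeholder id
feature_size = (336 // 14) ** 2             # 576
prompt_token_ids = [1, image_token_id, 1724]

expanded: list = []
for tok in prompt_token_ids:
    expanded.extend([tok] * feature_size if tok == image_token_id else [tok])
assert expanded.count(image_token_id) == feature_size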
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
|
||||
class CLIPVisionEmbeddings(nn.Module):
|
||||
|
||||
|
@ -6,7 +6,7 @@ from transformers import CLIPVisionConfig, LlavaConfig
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, VisionLanguageConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
@ -20,8 +20,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
|
||||
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
|
||||
input_processor_for_clip)
|
||||
from .interfaces import SupportsVision
|
||||
from .utils import merge_vision_embeddings
|
||||
|
||||
_KEYS_TO_MODIFY_MAPPING = {
|
||||
"language_model.lm_head": "lm_head",
|
||||
@ -51,28 +53,10 @@ class LlavaMultiModalProjector(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def merge_vision_embeddings(input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
vision_embeddings: torch.Tensor,
|
||||
image_token_id: int) -> torch.Tensor:
|
||||
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||
mask = (input_ids == image_token_id)
|
||||
|
||||
image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
|
||||
if mask.sum() != image_feature_size:
|
||||
raise ValueError(f"image_feature_size should be {image_feature_size}, "
|
||||
f"but found: {mask.sum()}")
|
||||
|
||||
inputs_embeds[mask] = vision_embeddings.view(image_feature_size,
|
||||
vision_embeddings.shape[-1])
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
class LlavaImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: torch.Tensor
|
||||
"""Shape: (batch_size, num_channels, height, width)"""
|
||||
"""Shape: `(batch_size, num_channels, height, width)`"""
|
||||
|
||||
|
||||
LlavaImageInputs = LlavaImagePixelInputs
|
||||
@ -96,8 +80,30 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int):
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config(LlavaConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
return input_processor_for_clip(
|
||||
model_config,
|
||||
vision_config,
|
||||
llm_inputs,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
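# NOTE: Illustrative sketch, not part of this commit. The decorators below are
# what wire a model into this pipeline: the input processor runs inside the
# engine before the model executor sees the request, and the input mapper turns
# the raw image into `pixel_values`. A minimal registration for a hypothetical
# model class might look like this (the class and processor names are
# assumptions):
from torch import nn

from vllm.inputs import INPUT_REGISTRY, LLMInputs


def passthrough_processor(ctx, llm_inputs: LLMInputs) -> LLMInputs:
    # A no-op processor; real processors expand the image placeholder tokens.
    return llm_inputs


@INPUT_REGISTRY.register_input_processor(passthrough_processor)
class MyToyModel(nn.Module):
    pass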
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
|
||||
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
|
||||
|
||||
def __init__(self,
|
||||
@ -112,7 +118,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
|
||||
|
||||
# TODO: Optionally initializes this for supporting embeddings.
|
||||
self.vision_tower = CLIPVisionModel(config.vision_config)
|
||||
|
||||
self.multi_modal_projector = LlavaMultiModalProjector(
|
||||
vision_hidden_size=config.vision_config.hidden_size,
|
||||
text_hidden_size=config.text_config.hidden_size,
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict
|
||||
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -10,7 +10,7 @@ from typing_extensions import NotRequired
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, VisionLanguageConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
@ -21,13 +21,14 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.clip import CLIPVisionModel
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
|
||||
get_clip_patch_grid_length)
|
||||
get_clip_patch_grid_length, input_processor_for_clip)
|
||||
from .interfaces import SupportsVision
|
||||
from .llava import LlavaMultiModalProjector, merge_vision_embeddings
|
||||
from .llava import LlavaMultiModalProjector
|
||||
from .utils import merge_vision_embeddings
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -39,16 +40,27 @@ _KEYS_TO_MODIFY_MAPPING = {
|
||||
|
||||
class LlavaNextImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: torch.Tensor
|
||||
"""Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""
|
||||
data: BatchedTensors
|
||||
"""
|
||||
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
|
||||
|
||||
Note that `num_patches` may be different for each batch.
|
||||
"""
|
||||
|
||||
image_sizes: NotRequired[torch.Tensor]
|
||||
"""Shape: (batch_size, 2)"""
|
||||
"""
|
||||
Shape: `(batch_size, 2)`
|
||||
|
||||
This should be in `(height, width)` format.
|
||||
"""
|
||||
|
||||
|
||||
LlavaNextImageInputs = LlavaNextImagePixelInputs
|
||||
|
||||
|
||||
# Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91
|
||||
# NOTE: new_height and new_width are further incremented to properly invert the
|
||||
# floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133
|
||||
def _get_llava_next_num_unpadded_features(
|
||||
height: int,
|
||||
width: int,
|
||||
@ -56,7 +68,6 @@ def _get_llava_next_num_unpadded_features(
|
||||
num_patch_height: int,
|
||||
num_patch_width: int,
|
||||
) -> Tuple[int, int]:
|
||||
# Taken from: https://github.com/huggingface/text-generation-inference/blob/799a193b109662743bed1b18a09af1fdcd508c8b/server/text_generation_server/models/vlm_causal_lm.py#L111
|
||||
current_height = npatches * num_patch_height
|
||||
current_width = npatches * num_patch_width
|
||||
|
||||
@ -64,9 +75,13 @@ def _get_llava_next_num_unpadded_features(
|
||||
current_aspect_ratio: float = current_width / current_height
|
||||
if aspect_ratio > current_aspect_ratio:
|
||||
new_height = (height * current_width) // width
|
||||
if new_height % 2 == 1:
|
||||
new_height += 1
|
||||
current_height = new_height
|
||||
else:
|
||||
new_width = (width * current_height) // height
|
||||
if new_width % 2 == 1:
|
||||
new_width += 1
|
||||
current_width = new_width
|
||||
|
||||
unpadded_features = current_height * current_width
|
||||
@ -74,7 +89,8 @@ def _get_llava_next_num_unpadded_features(
|
||||
return (unpadded_features, newline_features)
|
||||
|
||||
|
||||
def _get_llava_next_image_feature_size(
|
||||
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111
|
||||
def get_llava_next_image_feature_size(
|
||||
hf_config: LlavaNextConfig,
|
||||
*,
|
||||
input_height: int,
|
||||
@ -89,7 +105,9 @@ def _get_llava_next_image_feature_size(
|
||||
)
|
||||
base_feature_size = num_patches * num_patches
|
||||
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
# Note: We follow the "wrong" width/height order
|
||||
# [ref: PR huggingface/transformers#31588]
|
||||
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
||||
image_size=(input_height, input_width),
|
||||
grid_pinpoints=hf_config.image_grid_pinpoints,
|
||||
patch_size=vision_config.image_size,
|
||||
@ -110,14 +128,16 @@ def _get_llava_next_image_feature_size(
|
||||
|
||||
|
||||
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
|
||||
multimodal_config = ctx.get_multimodal_config()
|
||||
hf_config = ctx.get_hf_config(LlavaNextConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
#TODO: change the logic for dummy data to support dynamic shape
|
||||
_, _, dummy_height, dummy_width = multimodal_config.image_input_shape
|
||||
image_feature_size = _get_llava_next_image_feature_size(
|
||||
hf_config, input_height=dummy_height, input_width=dummy_width)
|
||||
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
|
||||
dummy_height = dummy_width = 448
|
||||
image_feature_size = get_llava_next_image_feature_size(
|
||||
hf_config,
|
||||
input_height=dummy_height,
|
||||
input_width=dummy_width,
|
||||
)
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
seq_data = dummy_seq_data_for_clip(
|
||||
@ -139,27 +159,47 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def _pixel_mapper(ctx: InputContext, image: object) -> Dict[str, torch.Tensor]:
|
||||
def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
|
||||
if isinstance(image, Image.Image):
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config(LlavaNextConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
# Temporary patch before dynamic number of image tokens is supported
|
||||
_, _, h, w = ctx.get_multimodal_config().image_input_shape
|
||||
if (w, h) != (image.width, image.height):
|
||||
logger.warning(
|
||||
"Dynamic image shape is currently not supported. "
|
||||
"Resizing input image to (%d, %d).", w, h)
|
||||
image_data = multi_modal_data["image"]
|
||||
if isinstance(image_data, Image.Image):
|
||||
width, height = image_data.size
|
||||
|
||||
image = image.resize((w, h))
|
||||
image_feature_size = get_llava_next_image_feature_size(
|
||||
hf_config,
|
||||
input_height=height,
|
||||
input_width=width,
|
||||
)
|
||||
elif isinstance(image_data, torch.Tensor):
|
||||
raise NotImplementedError("Embeddings input is not supported yet")
|
||||
else:
|
||||
raise TypeError(f"Invalid image type: {type(image_data)}")
|
||||
|
||||
return MULTIMODAL_REGISTRY._get_plugin("image") \
|
||||
._default_input_mapper(ctx, image)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
raise TypeError(f"Invalid type for 'image': {type(image)}")
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
return input_processor_for_clip(
|
||||
model_config,
|
||||
vision_config,
|
||||
llm_inputs,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
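# NOTE: Worked example, not part of this commit. Rough arithmetic behind
# `get_llava_next_image_feature_size` for a 448x448 input and a 336px CLIP
# tower, assuming the configured grid pinpoints select a 2x2 tile grid (the
# pinpoints and the selected grid are assumptions here):
patches_per_side = 336 // 14                    # 24
base_features = patches_per_side ** 2           # 576, from the resized base image
grid_h = grid_w = 2                             # 2x2 tiles of 336x336
unpadded = (patches_per_side * grid_h) * (patches_per_side * grid_w)  # 2304
newline_features = patches_per_side * grid_h    # 48, one "newline" per row
print(base_features + unpadded + newline_features)  # 2928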
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper(_pixel_mapper)
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
|
||||
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
|
||||
def __init__(self,
|
||||
@ -172,8 +212,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
self.config = config
|
||||
self.vlm_config = vlm_config
|
||||
|
||||
# TODO: Optionally initializes this for supporting embeddings.
|
||||
self.vision_tower = CLIPVisionModel(config=config.vision_config)
|
||||
|
||||
self.multi_modal_projector = LlavaMultiModalProjector(
|
||||
vision_hidden_size=config.vision_config.hidden_size,
|
||||
text_hidden_size=config.text_config.hidden_size,
|
||||
@ -196,24 +236,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
self.image_newline = nn.Parameter(
|
||||
torch.empty(config.text_config.hidden_size))
|
||||
|
||||
def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor:
|
||||
_, num_channels, _, _ = self.vlm_config.image_input_shape
|
||||
|
||||
# Note that this is different from that of vLLM vision_language_config
|
||||
# since the image is resized by the HuggingFace preprocessor
|
||||
height = width = self.config.vision_config.image_size
|
||||
|
||||
if list(data.shape[2:]) != [num_channels, height, width]:
|
||||
raise ValueError(
|
||||
f"The expected image tensor shape is batch dimension plus "
|
||||
f"num_patches plus {[num_channels, height, width]}. "
|
||||
f"You supplied {data.shape}. "
|
||||
f"If you are using vLLM's entrypoint, make sure your "
|
||||
f"supplied image input is consistent with "
|
||||
f"image_input_shape in engine args.")
|
||||
|
||||
return data
|
||||
|
||||
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
|
||||
if list(data.shape[1:]) != [2]:
|
||||
raise ValueError(
|
||||
@ -223,14 +245,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
return data
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
|
||||
self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_sizes = kwargs.pop("image_sizes", None)
|
||||
|
||||
if pixel_values is None or image_sizes is None:
|
||||
if pixel_values is None:
|
||||
return None
|
||||
|
||||
if not isinstance(pixel_values, torch.Tensor):
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
@ -240,7 +262,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
|
||||
return LlavaNextImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_image_pixels(pixel_values),
|
||||
data=pixel_values,
|
||||
image_sizes=self._validate_image_sizes(image_sizes),
|
||||
)
|
||||
|
||||
@ -267,15 +289,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
strategy=self.config.vision_feature_select_strategy,
|
||||
)
|
||||
|
||||
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
|
||||
def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
|
||||
patch_embeddings: torch.Tensor, *,
|
||||
strategy: str) -> torch.Tensor:
|
||||
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
|
||||
if strategy == "flat":
|
||||
return patch_embeddings.flatten(0, 1)
|
||||
|
||||
if strategy.startswith("spatial"):
|
||||
orig_width, orig_height = image_size
|
||||
height = width = self.config.vision_config.image_size \
|
||||
// self.config.vision_config.patch_size
|
||||
|
||||
@ -289,13 +310,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
other_patch_embeds = patch_embeddings[1:]
|
||||
|
||||
# image_aspect_ratio == "anyres"
|
||||
# Note: We follow the "wrong" width/height order
|
||||
# [ref: PR huggingface/transformers#31588]
|
||||
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
||||
(orig_width, orig_height),
|
||||
image_size,
|
||||
self.config.image_grid_pinpoints,
|
||||
self.config.vision_config.image_size,
|
||||
)
|
||||
other_patch_embeds = other_patch_embeds \
|
||||
.view(num_patch_width, num_patch_height, height, width, -1)
|
||||
.view(num_patch_height, num_patch_width, height, width, -1)
|
||||
|
||||
if "unpad" in strategy:
|
||||
other_patch_embeds = other_patch_embeds \
|
||||
@ -333,44 +356,53 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
raise ValueError(f"Unexpected patch merge strategy: {strategy}")
|
||||
|
||||
def _process_image_pixels(
|
||||
self, inputs: LlavaNextImagePixelInputs) -> torch.Tensor:
|
||||
self,
|
||||
inputs: LlavaNextImagePixelInputs,
|
||||
) -> BatchedTensors:
|
||||
assert self.vision_tower is not None
|
||||
|
||||
pixel_values = inputs["data"]
|
||||
|
||||
b, num_patches, c, h, w = pixel_values.shape
|
||||
stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
|
||||
if isinstance(pixel_values, torch.Tensor):
|
||||
b, num_patches, c, h, w = pixel_values.shape
|
||||
stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
|
||||
stacked_image_features = self._image_pixels_to_features(
|
||||
self.vision_tower, stacked_pixel_values)
|
||||
stacked_patch_embeddings = self.multi_modal_projector(
|
||||
stacked_image_features)
|
||||
|
||||
return stacked_patch_embeddings.view(
|
||||
b, num_patches, *stacked_patch_embeddings.shape[1:])
|
||||
|
||||
num_patches_per_batch = [v.shape[0] for v in pixel_values]
|
||||
stacked_pixel_values = torch.cat(pixel_values)
|
||||
stacked_image_features = self._image_pixels_to_features(
|
||||
self.vision_tower, stacked_pixel_values)
|
||||
|
||||
return stacked_image_features.view(b, num_patches,
|
||||
*stacked_image_features.shape[-2:])
|
||||
return [
|
||||
self.multi_modal_projector(image_features) for image_features in
|
||||
torch.split(stacked_image_features, num_patches_per_batch)
|
||||
]
|
||||
|
||||
def _process_image_input(
|
||||
self, image_input: LlavaNextImageInputs) -> torch.Tensor:
|
||||
assert self.vision_tower is not None
|
||||
image_features = self._process_image_pixels(image_input)
|
||||
|
||||
patch_embeddings = self.multi_modal_projector(image_features)
|
||||
self, image_input: LlavaNextImageInputs) -> BatchedTensors:
|
||||
patch_embeddings = self._process_image_pixels(image_input)
|
||||
|
||||
image_sizes = image_input.get("image_sizes")
|
||||
if image_sizes is None:
|
||||
batch_size = image_input["data"].shape[0]
|
||||
batch_size = len(image_input["data"])
|
||||
vision_config = self.config.vision_config
|
||||
default_width = default_height = vision_config.image_size
|
||||
image_sizes = torch.as_tensor([[default_width, default_height]
|
||||
default_height = default_width = vision_config.image_size
|
||||
image_sizes = torch.as_tensor([[default_height, default_width]
|
||||
for _ in range(batch_size)])
|
||||
|
||||
merged_patch_embeddings = [
|
||||
return [
|
||||
self._merge_image_patch_embeddings(image_sizes[i],
|
||||
patch_features,
|
||||
patch_features_batch,
|
||||
strategy="spatial_unpad")
|
||||
for i, patch_features in enumerate(patch_embeddings)
|
||||
for i, patch_features_batch in enumerate(patch_embeddings)
|
||||
]
|
||||
|
||||
return torch.stack(merged_patch_embeddings, dim=0)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -404,8 +436,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
input_ids: Flattened (concatenated) input_ids corresponding to a
|
||||
batch.
|
||||
pixel_values: The pixels in each grid patch for each input image.
|
||||
Expects a batch with shape `[1, num_patches, 3, 336, 336]`.
|
||||
image_sizes: The original `(width, height)` for each input image.
|
||||
Expects a batch with shape `[1, num_patches, 3, h, w]`.
|
||||
image_sizes: The original `(height, width)` for each input image.
|
||||
Expects a batch with shape `[1, 2]`.
|
||||
|
||||
See also:
|
||||
|
@ -13,7 +13,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict
|
||||
import re
|
||||
from functools import lru_cache
|
||||
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -22,8 +24,8 @@ from PIL import Image
|
||||
from transformers import CLIPVisionConfig, PretrainedConfig
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, VisionLanguageConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext
|
||||
from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
@ -34,10 +36,12 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.clip import CLIPVisionModel
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
|
||||
from vllm.multimodal.image import cached_get_tokenizer
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
|
||||
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
|
||||
input_processor_for_clip)
|
||||
from .interfaces import SupportsVision
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -251,50 +255,22 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
|
||||
class Phi3VImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: torch.Tensor
|
||||
"""Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""
|
||||
data: BatchedTensors
|
||||
"""
|
||||
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
|
||||
|
||||
Note that `num_patches` may be different for each batch.
|
||||
"""
|
||||
|
||||
image_sizes: torch.Tensor
|
||||
"""Shape: (batch_size, 2)"""
|
||||
"""
|
||||
Shape: `(batch_size, 2)`
|
||||
|
||||
This should be in `(height, width)` format.
|
||||
"""
|
||||
|
||||
|
||||
def _get_phi3v_image_feature_size(
|
||||
*,
|
||||
input_height: int,
|
||||
input_width: int,
|
||||
) -> int:
|
||||
h, w = input_height, input_width
|
||||
|
||||
# https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L178
|
||||
return (h // 336 * w // 336 + 1) * 144 + 1 + (h // 336 + 1) * 12
|
||||
|
||||
|
||||
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
|
||||
multimodal_config = ctx.get_multimodal_config()
|
||||
|
||||
#TODO: change the logic for dummy data to support dynamic shape
|
||||
_, _, dummy_height, dummy_width = multimodal_config.image_input_shape
|
||||
image_feature_size = _get_phi3v_image_feature_size(
|
||||
input_height=dummy_height,
|
||||
input_width=dummy_width,
|
||||
)
|
||||
|
||||
seq_data = dummy_seq_data_for_clip(
|
||||
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
|
||||
seq_len,
|
||||
image_token_id=32044,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
mm_data = dummy_image_for_clip(
|
||||
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
|
||||
image_width_override=dummy_width,
|
||||
image_height_override=dummy_height,
|
||||
)
|
||||
|
||||
return seq_data, mm_data
|
||||
|
||||
|
||||
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
|
||||
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
|
||||
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
|
||||
target_height = int(np.ceil(height / padding_unit) * padding_unit)
|
||||
top_padding = int((target_height - height) / 2)
|
||||
@ -304,7 +280,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
|
||||
return padded_width, padded_height
|
||||
|
||||
|
||||
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
|
||||
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
|
||||
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
|
||||
transposed = False
|
||||
if width < height:
|
||||
@ -329,27 +305,133 @@ def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
|
||||
return padded_width, padded_height
|
||||
|
||||
|
||||
def _image_processor(ctx: InputContext,
|
||||
image: object) -> Dict[str, torch.Tensor]:
|
||||
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
|
||||
def get_phi3v_image_feature_size(
|
||||
hf_config: PretrainedConfig,
|
||||
*,
|
||||
input_height: int,
|
||||
input_width: int,
|
||||
) -> int:
|
||||
num_crops = getattr(hf_config, "num_crops", 16)
|
||||
new_width, new_height = _calc_hd_transform_size(width=input_width,
|
||||
height=input_height,
|
||||
hd_num=num_crops)
|
||||
|
||||
if isinstance(image, Image.Image):
|
||||
# Temporary patch before dynamic number of image tokens is supported
|
||||
_, _, h, w = ctx.get_multimodal_config().image_input_shape
|
||||
if (w, h) != _calc_hd_transform_size(width=image.width,
|
||||
height=image.height):
|
||||
logger.warning(
|
||||
"Dynamic image shape is currently not supported. "
|
||||
"Resizing input image to (%d, %d).", w, h)
|
||||
|
||||
image = image.resize((w, h))
|
||||
|
||||
return MULTIMODAL_REGISTRY._get_plugin("image") \
|
||||
._default_input_mapper(ctx, image)
|
||||
raise TypeError(f"Invalid type for 'image': {type(image)}")
|
||||
return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
|
||||
+ (new_height // 336 + 1) * 12
|
||||
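# NOTE: Worked example, not part of this commit. Plugging a hypothetical
# HD-transformed size of 672x1344 (height x width) into the formula above:
new_height, new_width = 672, 1344
num_tiles = new_height // 336 * new_width // 336           # 2 * 4 = 8
feature_size = (num_tiles + 1) * 144 + 1 + (new_height // 336 + 1) * 12
print(feature_size)  # (8 + 1) * 144 + 1 + 3 * 12 = 1333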
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper(_image_processor)
|
||||
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
|
||||
# Result in the max possible feature size (h:w = 16:1)
|
||||
dummy_height, dummy_width = 8000, 50
|
||||
image_feature_size = get_phi3v_image_feature_size(
|
||||
ctx.get_hf_config(PretrainedConfig),
|
||||
input_height=dummy_height,
|
||||
input_width=dummy_width,
|
||||
)
|
||||
|
||||
seq_data = dummy_seq_data_for_clip(
|
||||
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
|
||||
seq_len,
|
||||
image_token_id=32044,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
mm_data = dummy_image_for_clip(
|
||||
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
|
||||
image_width_override=dummy_width,
|
||||
image_height_override=dummy_height,
|
||||
)
|
||||
|
||||
return seq_data, mm_data
|
||||
|
||||
|
||||
# Reserve this function to also handle placeholders for additional images
|
||||
# [ref: PR #5820]
|
||||
@lru_cache
|
||||
def _get_image_placeholder_token_ids(model_config: ModelConfig,
|
||||
idx: int) -> List[int]:
|
||||
assert idx > 0
|
||||
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||
|
||||
# We need to get the token for "<", not "▁<"
|
||||
# https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
|
||||
a_token_id, = tokenizer.encode("a", add_special_tokens=False)
|
||||
a_token_id_, *image_placeholder_token_ids = tokenizer.encode(
|
||||
f"a<|image_{idx}|>", add_special_tokens=False)
|
||||
assert a_token_id == a_token_id_
|
||||
|
||||
return image_placeholder_token_ids
|
||||
|
||||
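# NOTE: Illustrative sketch, not part of this commit. The prefix trick above
# works around the SentencePiece "▁<" vs "<" mismatch: tokenize "a<|image_1|>",
# check that the first id really is the one for "a", and keep the remainder as
# the placeholder ids. A standalone version of the same idea (loading the
# tokenizer by name is an assumption for illustration):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-vision-128k-instruct",
                                          trust_remote_code=True)
a_id, = tokenizer.encode("a", add_special_tokens=False)
a_id_check, *placeholder_ids = tokenizer.encode("a<|image_1|>",
                                                add_special_tokens=False)
assert a_id == a_id_check
print(placeholder_ids)  # the ids that stand for "<|image_1|>" in a prompt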
|
||||
def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
|
||||
model_config = ctx.model_config
|
||||
multimodal_config = ctx.get_multimodal_config()
|
||||
hf_config = ctx.get_hf_config(PretrainedConfig)
|
||||
|
||||
image_data = multi_modal_data["image"]
|
||||
if isinstance(image_data, Image.Image):
|
||||
w, h = image_data.size
|
||||
w, h = _calc_hd_transform_size(width=w, height=h)
|
||||
|
||||
image_feature_size = get_phi3v_image_feature_size(hf_config,
|
||||
input_width=w,
|
||||
input_height=h)
|
||||
elif isinstance(image_data, torch.Tensor):
|
||||
raise NotImplementedError("Embeddings input is not supported yet")
|
||||
else:
|
||||
raise TypeError(f"Invalid image type: {type(image_data)}")
|
||||
|
||||
prompt = llm_inputs.get("prompt")
|
||||
if prompt is None:
|
||||
new_prompt = None
|
||||
else:
|
||||
if prompt.count("<|image|>") > 0:
|
||||
logger.warning("Please follow the prompt format that is "
|
||||
"documented on HuggingFace which does not involve "
|
||||
"repeating <|image|> tokens.")
|
||||
elif len(re.findall(r"(<\|image_\d+\|>)+", prompt)) > 1:
|
||||
logger.warning("Multiple image input is not supported yet, "
|
||||
"so any extra image tokens will be treated "
|
||||
"as plain text.")
|
||||
|
||||
new_prompt = prompt
|
||||
|
||||
prompt_token_ids = llm_inputs["prompt_token_ids"]
|
||||
image_1_token_ids = _get_image_placeholder_token_ids(model_config, idx=1)
|
||||
|
||||
new_token_ids: List[int] = []
|
||||
for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
|
||||
if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
|
||||
new_token_ids.append(multimodal_config.image_token_id)
|
||||
|
||||
# No need to further scan the list since we only replace once
|
||||
new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
|
||||
break
|
||||
else:
|
||||
new_token_ids.append(prompt_token_ids[i])
|
||||
|
||||
# NOTE: Create a defensive copy of the original inputs
|
||||
llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
return input_processor_for_clip(
|
||||
model_config,
|
||||
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
|
||||
llm_inputs,
|
||||
image_token_id=multimodal_config.image_token_id,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
|
||||
class Phi3VForCausalLM(nn.Module, SupportsVision):
|
||||
|
||||
def __init__(self,
|
||||
@ -363,6 +445,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
|
||||
self.vlm_config = vlm_config
|
||||
|
||||
self.model = LlamaModel(config, cache_config, quant_config)
|
||||
|
||||
# TODO: Optionally initializes this for supporting embeddings.
|
||||
self.vision_embed_tokens = Phi3HDImageEmbedding(
|
||||
vlm_config, config, self.model.embed_tokens)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
@ -376,12 +460,20 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_sizes = kwargs.pop("image_sizes", None)
|
||||
|
||||
if pixel_values is not None and image_sizes is not None:
|
||||
return Phi3VImagePixelInputs(type="pixel_values",
|
||||
data=pixel_values,
|
||||
image_sizes=image_sizes)
|
||||
if pixel_values is None:
|
||||
return None
|
||||
|
||||
return None
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
if not isinstance(image_sizes, torch.Tensor):
|
||||
raise ValueError("Incorrect type of image sizes. "
|
||||
f"Got type: {type(image_sizes)}")
|
||||
|
||||
return Phi3VImagePixelInputs(type="pixel_values",
|
||||
data=pixel_values,
|
||||
image_sizes=image_sizes)
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
|
vllm/model_executor/models/utils.py (new file, +41 lines)
@@ -0,0 +1,41 @@
import torch

from vllm.multimodal import BatchedTensors


def merge_vision_embeddings(input_ids: torch.Tensor,
                            inputs_embeds: torch.Tensor,
                            vision_embeddings: BatchedTensors,
                            image_token_id: int) -> torch.Tensor:
    """
    Merge `vision_embeddings` into `inputs_embeds` by overwriting the positions
    in `inputs_embeds` corresponding to placeholder image tokens in `input_ids`.

    Note:
        This updates `inputs_embeds` in place.
    """
    mask = (input_ids == image_token_id)
    num_expected_tokens = mask.sum()

    if isinstance(vision_embeddings, torch.Tensor):
        batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape
        total_tokens = batch_size * batch_tokens
        if num_expected_tokens != total_tokens:
            expr = f"{batch_size} x {batch_tokens}"
            raise ValueError(
                f"Attempted to assign {expr} = {total_tokens} "
                f"image tokens to {num_expected_tokens} placeholders")

        inputs_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim)
    else:
        size_per_batch = [t.shape[0] for t in vision_embeddings]
        total_tokens = sum(size_per_batch)
        if num_expected_tokens != total_tokens:
            expr = ' + '.join(map(str, size_per_batch))
            raise ValueError(
                f"Attempted to assign {expr} = {total_tokens} "
                f"image tokens to {num_expected_tokens} placeholders")

        inputs_embeds[mask] = torch.cat(vision_embeddings)

    return inputs_embeds
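# NOTE: Illustrative sketch, not part of this commit. A tiny standalone
# demonstration of what `merge_vision_embeddings` does: positions of
# `input_ids` that equal the placeholder id are overwritten, in place, with the
# flattened vision embeddings (all sizes below are toy values):
import torch

IMAGE_TOKEN_ID = 99
input_ids = torch.tensor([1, 99, 99, 99, 2])          # 3 placeholder positions
inputs_embeds = torch.zeros(5, 4)                     # (num_tokens, hidden_size)
vision_embeddings = torch.arange(12.).view(1, 3, 4)   # (1 image, 3 tokens, 4)

mask = input_ids == IMAGE_TOKEN_ID
inputs_embeds[mask] = vision_embeddings.view(3, 4)
assert torch.equal(inputs_embeds[1:4], vision_embeddings[0])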
@ -1,4 +1,5 @@
|
||||
from .base import MultiModalDataDict, MultiModalPlugin
|
||||
from .base import (BatchedTensors, MultiModalDataDict, MultiModalInputs,
|
||||
MultiModalPlugin)
|
||||
from .registry import MultiModalRegistry
|
||||
|
||||
MULTIMODAL_REGISTRY = MultiModalRegistry()
|
||||
@ -11,8 +12,10 @@ See also:
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
"BatchedTensors",
|
||||
"MultiModalDataDict",
|
||||
"MultiModalInputs",
|
||||
"MultiModalPlugin",
|
||||
"MULTIMODAL_REGISTRY",
|
||||
"MultiModalRegistry",
|
||||
"MultiModalDataDict",
|
||||
]
|
||||
|
@ -1,23 +1,90 @@
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Type,
|
||||
TypedDict, TypeVar, Union)
|
||||
from collections import UserDict, defaultdict
|
||||
from typing import (Any, Callable, Dict, List, Optional, Type, TypedDict,
|
||||
TypeVar, Union)
|
||||
|
||||
import torch
|
||||
import torch.types
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
N = TypeVar("N", bound=Type["nn.Module"])
|
||||
BatchedTensors = Union[torch.Tensor, List[torch.Tensor]]
|
||||
"""
|
||||
If each input tensor in the batch has the same size, this is a single batched
|
||||
tensor; otherwise, this is a list of tensors with one element per batch.
|
||||
"""
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
# UserDict cannot be subscripted
|
||||
class _MultiModalInputsBase(UserDict):
|
||||
pass
|
||||
else:
|
||||
|
||||
class _MultiModalInputsBase(UserDict[str, torch.Tensor]):
|
||||
pass
|
||||
|
||||
|
||||
class MultiModalInputs(_MultiModalInputsBase):
    """
    A dictionary that represents the keyword arguments to
    :meth:`~torch.nn.Module.forward`.
    """

    @staticmethod
    def try_concat(
        tensors: List[torch.Tensor],
        *,
        device: torch.types.Device,
    ) -> BatchedTensors:
        # Avoid initializing CUDA too early
        import torch

        unbatched_shape = tensors[0].shape[1:]

        for tensor in tensors:
            if tensor.shape[1:] != unbatched_shape:
                return [
                    tensor.squeeze(0).to(device=device) for tensor in tensors
                ]

        return torch.cat(tensors, dim=0).to(device=device)

    @staticmethod
    def batch(
        inputs_list: List["MultiModalInputs"],
        device: torch.types.Device,
    ) -> Dict[str, BatchedTensors]:
        """Batch multiple inputs together into a dictionary."""
        if len(inputs_list) == 0:
            return {}

        keys = inputs_list[0].keys()

        item_lists: Dict[str, List[torch.Tensor]] = defaultdict(list)

        for inputs in inputs_list:
            if inputs.keys() != keys:
                msg = f"Inputs do not share the same keys ({keys})"
                raise ValueError(msg)

            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalInputs.try_concat(item_list, device=device)
            for k, item_list in item_lists.items()
        }
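# NOTE: Illustrative sketch, not part of this commit. `try_concat` is what
# makes dynamic image sizes batchable: if every request produced the same
# tensor shape, the result is one stacked tensor; otherwise it is a list with
# one tensor per request. Toy shapes, assuming the `MultiModalInputs` class
# defined above is in scope:
import torch

same = [torch.zeros(1, 3, 336, 336), torch.zeros(1, 3, 336, 336)]
mixed = [torch.zeros(1, 5, 3, 336, 336), torch.zeros(1, 2, 3, 336, 336)]

batched = MultiModalInputs.try_concat(same, device="cpu")
per_request = MultiModalInputs.try_concat(mixed, device="cpu")
assert isinstance(batched, torch.Tensor) and batched.shape[0] == 2
assert isinstance(per_request, list) and per_request[0].shape[0] == 5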
|
||||
|
||||
class MultiModalDataBuiltins(TypedDict, total=False):
|
||||
image: "Image.Image"
|
||||
image: Image.Image
|
||||
|
||||
|
||||
MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
|
||||
@ -29,12 +96,13 @@ to the model by the corresponding mapper. By default, the mapper of
|
||||
the corresponding plugin with the same modality key is applied.
|
||||
"""
|
||||
|
||||
MultiModalInputMapper = Callable[[InputContext, object], Dict[str,
|
||||
"torch.Tensor"]]
|
||||
MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs]
|
||||
"""Return a dictionary to be passed as keyword arguments to
|
||||
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
|
||||
and processors in HuggingFace Transformers."""
|
||||
|
||||
N = TypeVar("N", bound=Type[nn.Module])
|
||||
|
||||
|
||||
class MultiModalPlugin(ABC):
|
||||
"""
|
||||
@ -48,8 +116,7 @@ class MultiModalPlugin(ABC):
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._input_mappers: Dict[Type["nn.Module"],
|
||||
MultiModalInputMapper] = {}
|
||||
self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
|
||||
|
||||
@abstractmethod
|
||||
def get_data_key(self) -> str:
|
||||
@ -60,7 +127,7 @@ class MultiModalPlugin(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def _default_input_mapper(self, ctx: InputContext,
|
||||
data: object) -> Dict[str, "torch.Tensor"]:
|
||||
data: object) -> MultiModalInputs:
|
||||
"""Return a dictionary to be passed as keyword arguments to
|
||||
:meth:`~torch.nn.Module.forward`. This is similar in concept to
|
||||
tokenizers and processors in HuggingFace Transformers.
|
||||
@ -80,6 +147,7 @@ class MultiModalPlugin(ABC):
|
||||
|
||||
See also:
|
||||
:ref:`input_processing_pipeline`
|
||||
:ref:`adding_a_new_multimodal_model`
|
||||
"""
|
||||
|
||||
def wrapper(model_cls: N) -> N:
|
||||
@ -97,7 +165,7 @@ class MultiModalPlugin(ABC):
|
||||
return wrapper
|
||||
|
||||
def map_input(self, model_config: ModelConfig,
|
||||
data: object) -> Dict[str, "torch.Tensor"]:
|
||||
data: object) -> MultiModalInputs:
|
||||
"""
|
||||
Apply an input mapper to a data passed
|
||||
to the model, transforming the data into a dictionary of model inputs.
|
||||
@ -106,7 +174,8 @@ class MultiModalPlugin(ABC):
|
||||
|
||||
The model is identified by ``model_config``.
|
||||
|
||||
TODO: Add guide [ref: PR #5276]
|
||||
See also:
|
||||
:ref:`adding_a_new_multimodal_model`
|
||||
"""
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.model_loader import get_model_architecture
|
||||
|
@ -1,19 +1,102 @@
|
||||
from functools import lru_cache
|
||||
from typing import Dict
|
||||
from typing import List, Optional, Tuple, TypeVar
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs.registry import InputContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.image_processor import get_image_processor
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
from .base import MultiModalPlugin
|
||||
from .base import MultiModalInputs, MultiModalPlugin
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
cached_get_image_processor = lru_cache(get_image_processor)
|
||||
cached_get_tokenizer = lru_cache(get_tokenizer)
|
||||
|
||||
# Utilities for image input processors
|
||||
_T = TypeVar("_T", str, int)
|
||||
|
||||
|
||||
def repeat_and_pad_token(
    token: _T,
    *,
    repeat_count: int = 1,
    pad_token_left: Optional[_T] = None,
    pad_token_right: Optional[_T] = None,
) -> List[_T]:
    replacement = [token] * repeat_count
    if pad_token_left is not None:
        replacement = [pad_token_left] + replacement
    if pad_token_right is not None:
        replacement = replacement + [pad_token_right]

    return replacement
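# NOTE: Illustrative sketch, not part of this commit. `repeat_and_pad_token`
# works for both strings and token ids:
assert repeat_and_pad_token("<image>", repeat_count=3) == ["<image>"] * 3
assert repeat_and_pad_token(42, repeat_count=2, pad_token_left=7) == [7, 42, 42]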
|
||||
|
||||
def repeat_and_pad_image_tokens(
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
prompt: Optional[str],
|
||||
prompt_token_ids: List[int],
|
||||
*,
|
||||
image_token_id: int,
|
||||
repeat_count: int = 1,
|
||||
pad_token_left: Optional[int] = None,
|
||||
pad_token_right: Optional[int] = None,
|
||||
) -> Tuple[Optional[str], List[int]]:
|
||||
if prompt is None:
|
||||
new_prompt = None
|
||||
else:
|
||||
image_token_str = tokenizer.decode(image_token_id)
|
||||
pad_token_str_left = (None if pad_token_left is None else
|
||||
tokenizer.decode(pad_token_left))
|
||||
pad_token_str_right = (None if pad_token_right is None else
|
||||
tokenizer.decode(pad_token_right))
|
||||
replacement_str = "".join(
|
||||
repeat_and_pad_token(
|
||||
image_token_str,
|
||||
repeat_count=repeat_count,
|
||||
pad_token_left=pad_token_str_left,
|
||||
pad_token_right=pad_token_str_right,
|
||||
))
|
||||
|
||||
image_token_count = prompt.count(image_token_str)
|
||||
# This is an arbitrary number to distinguish between the two cases
|
||||
if image_token_count > 16:
|
||||
logger.warning(
|
||||
"Please follow the prompt format that is "
|
||||
"documented on HuggingFace which does not involve "
|
||||
"repeating %s tokens.", image_token_str)
|
||||
elif image_token_count > 1:
|
||||
logger.warning("Multiple image input is not supported yet, "
|
||||
"so any extra image tokens will be treated "
|
||||
"as plain text.")
|
||||
|
||||
# The image tokens are removed to be consistent with HuggingFace
|
||||
new_prompt = prompt.replace(image_token_str, replacement_str, 1)
|
||||
|
||||
new_token_ids: List[int] = []
|
||||
for i, token in enumerate(prompt_token_ids):
|
||||
if token == image_token_id:
|
||||
replacement_ids = repeat_and_pad_token(
|
||||
image_token_id,
|
||||
repeat_count=repeat_count,
|
||||
pad_token_left=pad_token_left,
|
||||
pad_token_right=pad_token_right,
|
||||
)
|
||||
new_token_ids.extend(replacement_ids)
|
||||
|
||||
# No need to further scan the list since we only replace once
|
||||
new_token_ids.extend(prompt_token_ids[i + 1:])
|
||||
break
|
||||
else:
|
||||
new_token_ids.append(token)
|
||||
|
||||
return new_prompt, new_token_ids
|
||||
|
||||
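# NOTE: Illustrative sketch, not part of this commit. On the prompt-string
# side, `repeat_and_pad_image_tokens` replaces only the first placeholder
# occurrence, mirroring the HuggingFace processors (the placeholder string and
# repeat count below are toy values):
prompt = "USER: <image>\nWhat is shown? ASSISTANT:"
new_prompt = prompt.replace("<image>", "<image>" * 3, 1)
assert new_prompt.count("<image>") == 3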
|
||||
class ImagePlugin(MultiModalPlugin):
|
||||
@ -27,7 +110,7 @@ class ImagePlugin(MultiModalPlugin):
|
||||
trust_remote_code=model_config.trust_remote_code)
|
||||
|
||||
def _default_input_mapper(self, ctx: InputContext,
|
||||
data: object) -> Dict[str, torch.Tensor]:
|
||||
data: object) -> MultiModalInputs:
|
||||
model_config = ctx.model_config
|
||||
if isinstance(data, Image.Image):
|
||||
image_processor = self._get_hf_image_processor(model_config)
|
||||
@ -35,10 +118,15 @@ class ImagePlugin(MultiModalPlugin):
|
||||
raise RuntimeError("No HuggingFace processor is available"
|
||||
"to process the image object")
|
||||
try:
|
||||
return image_processor.preprocess(data, return_tensors="pt") \
|
||||
.to(model_config.dtype).data
|
||||
batch_data = image_processor \
|
||||
.preprocess(data, return_tensors="pt") \
|
||||
.data
|
||||
except Exception:
|
||||
logger.error("Failed to process image (%s)", data)
|
||||
raise
|
||||
|
||||
raise TypeError(f"Invalid type for 'image': {type(data)}")
|
||||
return MultiModalInputs(batch_data)
|
||||
elif isinstance(data, torch.Tensor):
|
||||
raise NotImplementedError("Embeddings input is not supported yet")
|
||||
|
||||
raise TypeError(f"Invalid image type: {type(data)}")
|
||||
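# NOTE: Illustrative sketch, not part of this commit. The default image mapper
# defers to the model's HuggingFace image processor and wraps its output dict;
# it is roughly equivalent to the following (the model name is an assumption
# used only for illustration):
from PIL import Image
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
image = Image.new("RGB", (336, 336))
batch_data = processor.preprocess(image, return_tensors="pt").data
print(batch_data["pixel_values"].shape)  # e.g. torch.Size([1, 3, 336, 336])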
|
@ -1,18 +1,17 @@
|
||||
import functools
|
||||
from typing import Optional, Sequence, Type, TypeVar
|
||||
from typing import Dict, Optional, Sequence
|
||||
|
||||
from torch import nn
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .base import MultiModalDataDict, MultiModalInputMapper, MultiModalPlugin
|
||||
from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
|
||||
MultiModalPlugin)
|
||||
from .image import ImagePlugin
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
N = TypeVar("N", bound=Type[nn.Module])
|
||||
|
||||
|
||||
class MultiModalRegistry:
|
||||
"""
|
||||
@ -61,7 +60,7 @@ class MultiModalRegistry:
|
||||
return self.register_input_mapper("image", mapper)
|
||||
|
||||
def _process_input(self, key: str, value: object,
|
||||
model_config: ModelConfig):
|
||||
model_config: ModelConfig) -> MultiModalInputs:
|
||||
plugin = self._plugins.get(key)
|
||||
if plugin:
|
||||
return plugin.map_input(model_config, value)
|
||||
@ -93,16 +92,28 @@ class MultiModalRegistry:
|
||||
"""
|
||||
return self.register_input_mapper("image", mapper)
|
||||
|
||||
def map_input(self, model_config: ModelConfig, data: MultiModalDataDict):
|
||||
def map_input(self, model_config: ModelConfig,
|
||||
data: MultiModalDataDict) -> MultiModalInputs:
|
||||
"""
|
||||
Apply an input mapper to the data passed to the model.
|
||||
|
||||
See :meth:`MultiModalPlugin.map_input` for more details.
|
||||
"""
|
||||
result_list = [
|
||||
self._process_input(k, v, model_config) for k, v in data.items()
|
||||
]
|
||||
return {k: v for d in result_list for k, v in d.items()}
|
||||
merged_dict: Dict[str, torch.Tensor] = {}
|
||||
|
||||
for data_key, data_value in data.items():
|
||||
input_dict = self._process_input(data_key, data_value,
|
||||
model_config)
|
||||
|
||||
for input_key, input_tensor in input_dict.items():
|
||||
if input_key in merged_dict:
|
||||
raise ValueError(f"The input mappers (keys={set(data)}) "
|
||||
f"resulted in a conflicting keyword "
|
||||
f"argument to `forward()`: {input_key}")
|
||||
|
||||
merged_dict[input_key] = input_tensor
|
||||
|
||||
return MultiModalInputs(merged_dict)
|
||||
|
||||
def create_input_mapper(self, model_config: ModelConfig):
|
||||
"""
|
||||
|
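The new `map_input` above merges the keyword tensors produced by each modality's mapper and raises an error instead of silently overwriting a duplicate key. A simplified standalone sketch of that merge (the helper name below is illustrative, not part of vLLM):

from typing import Dict

import torch


def merge_mapper_outputs(
        *mapper_outputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    merged: Dict[str, torch.Tensor] = {}
    for output in mapper_outputs:
        for key, tensor in output.items():
            if key in merged:
                # Two modalities produced the same forward() keyword argument.
                raise ValueError(
                    f"Conflicting keyword argument to `forward()`: {key}")
            merged[key] = tensor
    return merged


# e.g. merge_mapper_outputs({"pixel_values": torch.zeros(1, 3, 336, 336)},
#                           {"image_sizes": torch.tensor([[336, 336]])})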
@ -4,11 +4,56 @@ from typing import Optional, Union
from urllib.parse import urlparse

import aiohttp
import requests
from PIL import Image

from vllm.config import ModelConfig
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
from vllm.multimodal.base import MultiModalDataDict
from vllm.version import __version__ as VLLM_VERSION


def _validate_remote_url(url: str, *, name: str):
    parsed_url = urlparse(url)
    if parsed_url.scheme not in ["http", "https"]:
        raise ValueError(f"Invalid '{name}': A valid '{name}' "
                         "must have scheme 'http' or 'https'.")


def _get_request_headers():
    return {"User-Agent": f"vLLM/{VLLM_VERSION}"}


def _load_image_from_bytes(b: bytes):
    image = Image.open(BytesIO(b))
    image.load()
    return image


def _load_image_from_data_url(image_url: str):
    # Only split once and assume the second part is the base64 encoded image
    _, image_base64 = image_url.split(",", 1)
    return load_image_from_base64(image_base64)


def fetch_image(image_url: str) -> Image.Image:
    """Load PIL image from a url or base64 encoded openai GPT4V format"""
    if image_url.startswith('http'):
        _validate_remote_url(image_url, name="image_url")

        headers = _get_request_headers()

        with requests.get(url=image_url, headers=headers) as response:
            response.raise_for_status()
            image_raw = response.content
        image = _load_image_from_bytes(image_raw)

    elif image_url.startswith('data:image'):
        image = _load_image_from_data_url(image_url)
    else:
        raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
                         "with either 'data:image' or 'http'.")

    return image

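`fetch_image` above accepts either an HTTP(S) URL or an OpenAI GPT-4V-style base64 data URL. A quick standalone check of the data-URL path, using only Pillow and the standard library rather than vLLM itself:

import base64
from io import BytesIO

from PIL import Image

# Encode a small image as a base64 data URL, the format accepted above.
original = Image.new("RGB", (8, 8), color="red")
buffer = BytesIO()
original.save(buffer, format="JPEG")
image_url = ("data:image/jpeg;base64," +
             base64.b64encode(buffer.getvalue()).decode("utf-8"))

# The 'data:image' branch splits on the first comma and decodes the remainder.
_, payload = image_url.split(",", 1)
roundtripped = Image.open(BytesIO(base64.b64decode(payload)))
roundtripped.load()
assert roundtripped.size == (8, 8)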
class ImageFetchAiohttp:
@ -29,34 +74,31 @@ class ImageFetchAiohttp:
        """Load PIL image from a url or base64 encoded openai GPT4V format"""

        if image_url.startswith('http'):
            parsed_url = urlparse(image_url)
            if parsed_url.scheme not in ["http", "https"]:
                raise ValueError("Invalid 'image_url': A valid 'image_url' "
                                 "must have scheme 'http' or 'https'.")
            # Avoid circular import
            from vllm import __version__ as VLLM_VERSION
            _validate_remote_url(image_url, name="image_url")

            client = cls.get_aiohttp_client()
            headers = {"User-Agent": f"vLLM/{VLLM_VERSION}"}
            headers = _get_request_headers()

            async with client.get(url=image_url, headers=headers) as response:
                response.raise_for_status()
                image_raw = await response.read()
            image = Image.open(BytesIO(image_raw))
            image = _load_image_from_bytes(image_raw)

        # Only split once and assume the second part is the base64 encoded image
        elif image_url.startswith('data:image'):
            image = load_image_from_base64(image_url.split(',', 1)[1])

            image = _load_image_from_data_url(image_url)
        else:
            raise ValueError(
                "Invalid 'image_url': A valid 'image_url' must start "
                "with either 'data:image' or 'http'.")

        image.load()
        return image


async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
    image = await ImageFetchAiohttp.fetch_image(image_url)
    return {"image": image}


def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str:
    """Encode a pillow image to base64 format."""

@ -69,26 +111,11 @@ def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str:

def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
    """Load image from base64 format."""
    return Image.open(BytesIO(base64.b64decode(image)))
    return _load_image_from_bytes(base64.b64decode(image))


# TODO(ywang96): move this to a model registry for preprocessing vision
# language prompts based on the model type.
def get_full_image_text_prompt(image_prompt: str, text_prompt: str,
                               config: ModelConfig) -> str:
    """Combine image and text prompts for vision language model depending on
    the model architecture."""

    if config.hf_config.model_type in ("llava", "llava_next"):
        full_prompt = f"{image_prompt}\n{text_prompt}"
    elif config.hf_config.model_type == 'phi3_v':
        full_prompt = f"{image_prompt}<s>\n{text_prompt}"
    else:
        raise ValueError(
            f"Unsupported model type: {config.hf_config.model_type}")
    return full_prompt


async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
    image = await ImageFetchAiohttp.fetch_image(image_url)
    return {"image": image}
def rescale_image_size(image: Image.Image, size_factor: float) -> Image.Image:
    """Rescale the dimensions of an image by a constant factor."""
    new_width = int(image.width * size_factor)
    new_height = int(image.height * size_factor)
    return image.resize((new_width, new_height))

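`rescale_image_size` gives callers a simple way to derive differently sized variants of one image, which is the kind of input the dynamic-size path is meant to handle. A usage sketch (the function is reproduced locally rather than imported, since its module path is not shown in this hunk):

from PIL import Image


def rescale_image_size(image: Image.Image, size_factor: float) -> Image.Image:
    """Rescale the dimensions of an image by a constant factor."""
    new_width = int(image.width * size_factor)
    new_height = int(image.height * size_factor)
    return image.resize((new_width, new_height))


half = rescale_image_size(Image.new("RGB", (640, 480)), 0.5)
assert half.size == (320, 240)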
@ -457,7 +457,7 @@ class SequenceGroup:
        return next(iter(self.seqs_dict.values())).prompt_token_ids

    @property
    def multi_modal_data(self) -> Optional["MultiModalDataDict"]:
    def multi_modal_data(self) -> "MultiModalDataDict":
        # All sequences in the group should have the same multi-modal data.
        # We use the multi-modal data of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).multi_modal_data

@ -1,9 +1,4 @@
from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor

from vllm.logger import init_logger

logger = init_logger(__name__)
from typing import cast


def get_image_processor(
@ -11,10 +6,15 @@ def get_image_processor(
    *args,
    trust_remote_code: bool = False,
    **kwargs,
) -> BaseImageProcessor:
):
    """Gets an image processor for the given model name via HuggingFace."""
    # don't put this import at the top level
    # it will call torch.cuda.device_count()
    from transformers import AutoImageProcessor
    from transformers.image_processing_utils import BaseImageProcessor

    try:
        processor: BaseImageProcessor = AutoImageProcessor.from_pretrained(
        processor = AutoImageProcessor.from_pretrained(
            processor_name,
            *args,
            trust_remote_code=trust_remote_code,
@ -34,4 +34,4 @@ def get_image_processor(
    else:
        raise e

    return processor
    return cast(BaseImageProcessor, processor)

@ -1,6 +1,6 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
                    Type, Union)

import torch
from torch import nn
@ -12,7 +12,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
                             MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.utils import make_tensor_with_pad
@ -40,7 +41,7 @@ class CPUModelInput(ModelRunnerInputBase):
    input_positions: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    sampling_metadata: Optional["SamplingMetadata"] = None
    multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
@ -132,15 +133,14 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], Dict[
            str, torch.Tensor]]:
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               Mapping[str, BatchedTensors]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        multi_modal_kwargs_list: Dict[str,
                                      List[torch.Tensor]] = defaultdict(list)
        multi_modal_inputs_list: List[MultiModalInputs] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
@ -162,10 +162,9 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
            input_positions.extend(list(range(computed_len, seq_len)))

            mm_data = seq_group_metadata.multi_modal_data
            if mm_data is not None:
            if mm_data:
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                for k, v in mm_kwargs.items():
                    multi_modal_kwargs_list[k].append(v)
                multi_modal_inputs_list.append(mm_kwargs)

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
@ -189,11 +188,6 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        multi_modal_kwargs = {
            k: torch.cat(v, dim=0).to(self.device)
            for k, v in multi_modal_kwargs_list.items()
        }

        num_prompt_tokens = len(input_tokens)

        input_tokens = torch.tensor(input_tokens,
@ -217,6 +211,10 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
            block_tables=torch.tensor([]),
            slot_mapping=slot_mapping,
        )

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
                                                    device=self.device)

        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_kwargs)

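Across the model runners, the per-key `defaultdict` plus `torch.cat` pattern is replaced by collecting one `MultiModalInputs` object per sequence and batching them with a single `MultiModalInputs.batch(...)` call. A simplified sketch of the grouping idea only (not the actual `MultiModalInputs.batch` implementation, which also moves tensors to the target device and handles per-item shape differences):

from collections import defaultdict
from typing import Dict, List

import torch


def group_keyword_tensors(
        inputs_list: List[Dict[str, torch.Tensor]]
) -> Dict[str, List[torch.Tensor]]:
    # Collect each sequence's mapped tensors under their keyword name so the
    # whole batch can be passed to forward() in one call.
    grouped: Dict[str, List[torch.Tensor]] = defaultdict(list)
    for item in inputs_list:
        for key, tensor in item.items():
            grouped[key].append(tensor)
    return dict(grouped)


# e.g. two sequences, each contributing its own pixel_values tensor:
# group_keyword_tensors([{"pixel_values": torch.zeros(1, 3, 336, 336)},
#                        {"pixel_values": torch.zeros(1, 3, 672, 672)}])

Keeping the per-sequence tensors separate is what lets images of different sizes coexist in one batch, which a plain `torch.cat` along dim 0 could not do.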
@ -367,10 +365,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
            "positions": model_input.input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": model_input.attn_metadata,
            **(model_input.multi_modal_kwargs or {}),
        }
        if (self.vision_language_config
                and model_input.multi_modal_kwargs is not None):
            execute_model_kwargs.update(model_input.multi_modal_kwargs)

        hidden_states = model_executable(**execute_model_kwargs)

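The runner-side change above drops the `vision_language_config` special case and simply spreads the mapped tensors into the model call, relying on the keyword parameters the model reserves in its `forward()`. A minimal sketch with a hypothetical model class (not a real vLLM model):

import torch


class ToyVisionModel(torch.nn.Module):

    def forward(self,
                input_ids: torch.Tensor,
                *,
                pixel_values: torch.Tensor = None) -> torch.Tensor:
        # A real model would embed input_ids and merge the image features;
        # here we only show that the keyword argument arrives intact.
        assert pixel_values is not None
        return input_ids


model = ToyVisionModel()
multi_modal_kwargs = {"pixel_values": torch.zeros(1, 3, 336, 336)}
hidden = model(torch.tensor([[1, 2, 3]]), **(multi_modal_kwargs or {}))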
@ -92,10 +92,9 @@ class EmbeddingModelRunner(
            "positions": model_input.input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": model_input.attn_metadata,
            **(model_input.multi_modal_kwargs or {}),
        }
        if self.vision_language_config:
            multi_modal_kwargs = model_input.multi_modal_kwargs or {}
            execute_model_kwargs.update({"image_input": multi_modal_kwargs})

        hidden_states = model_executable(**execute_model_kwargs)

        # Only perform pooling in the driver worker.

@ -3,8 +3,8 @@ import gc
import time
import warnings
from collections import defaultdict
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
                    TypeVar, Union)
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
                    Tuple, Type, TypeVar, Union)

import numpy as np
import torch
@ -37,7 +37,8 @@ from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import supports_lora
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
                             MultiModalInputs)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput,
                           SequenceGroupMetadata)
@ -83,7 +84,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
    lora_mapping: Optional["LoRAMapping"] = None
    lora_requests: Optional[Set[LoRARequest]] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
    request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
    finished_requests_ids: Optional[List[str]] = None
    virtual_engine: int = 0
@ -356,8 +357,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        context_lens: List[int] = []
        query_lens: List[int] = []
        block_tables: List[List[int]] = []
        multi_modal_kwargs_list: Dict[str,
                                      List[torch.Tensor]] = defaultdict(list)
        multi_modal_inputs_list: List[MultiModalInputs] = []
        request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list)
        decode_only = True
        num_prefills = 0
@ -528,8 +528,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
            if mm_data:
                # Process multi-modal data
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                for k, v in mm_kwargs.items():
                    multi_modal_kwargs_list[k].append(v)
                multi_modal_inputs_list.append(mm_kwargs)

            is_profile_run = _is_block_tables_empty(
                seq_group_metadata.block_tables)
@ -746,10 +745,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        else:
            lora_mapping = None

        multi_modal_kwargs = {
            k: torch.cat(v, dim=0).to(self.device)
            for k, v in multi_modal_kwargs_list.items()
        }
        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
                                                    device=self.device)
        request_ids_to_seq_ids = {
            seq_group_metadata.request_id:
            list(seq_group_metadata.seq_data.keys())
@ -821,7 +818,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):

        seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
            .dummy_data_for_profiling(model_config, seq_len)
        assert len(seq_data.prompt_token_ids) == seq_len

        # Having more tokens is over-conservative but otherwise fine
        assert len(seq_data.prompt_token_ids) >= seq_len, (
            f"Expected at least {seq_len} dummy tokens for profiling, "
            f"but got: {len(seq_data.prompt_token_ids)}")

        seq = SequenceGroupMetadata(
            request_id=str(group_id),

@ -1,5 +1,6 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
                    Union)

import torch
from torch import nn
@ -9,6 +10,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
                             MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
@ -29,6 +32,7 @@ class ModelInputForNeuron(ModelRunnerInputBase):
    input_positions: Optional[torch.Tensor] = None
    input_block_ids: Optional[torch.Tensor] = None
    sampling_metadata: Optional["SamplingMetadata"] = None
    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
@ -65,6 +69,10 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # initialize after load_model.

@ -76,13 +84,15 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], Mapping[
            str, BatchedTensors]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []

        seq_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
@ -102,6 +112,12 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
            assert len(block_table) == 1
            input_block_ids.append(block_table[0])

            mm_data = seq_group_metadata.multi_modal_data
            if mm_data:
                # Process multi-modal data
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                multi_modal_inputs_list.append(mm_kwargs)

        max_seq_len = max(seq_lens)
        assert max_seq_len > 0
        input_tokens = make_tensor_with_pad(input_tokens,
@ -118,7 +134,11 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
                                            dtype=torch.long,
                                            device=self.device)

        return input_tokens, input_positions, input_block_ids, seq_lens
        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
                                                    device=self.device)

        return (input_tokens, input_positions, input_block_ids, seq_lens,
                multi_modal_kwargs)

    def _prepare_decode(
        self,
@ -184,8 +204,9 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
        is_prompt = seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, input_block_ids,
             seq_lens) = self._prepare_prompt(seq_group_metadata_list)
            (input_tokens, input_positions, input_block_ids, seq_lens,
             multi_modal_kwargs
             ) = self._prepare_prompt(seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
@ -203,7 +224,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
        return ModelInputForNeuron(input_tokens=input_tokens,
                                   input_positions=input_positions,
                                   input_block_ids=input_block_ids,
                                   sampling_metadata=sampling_metadata)
                                   sampling_metadata=sampling_metadata,
                                   multi_modal_kwargs=multi_modal_kwargs)

    @torch.inference_mode()
    def execute_model(
@ -221,6 +243,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            input_block_ids=model_input.input_block_ids,
            **(model_input.multi_modal_kwargs or {}),
        )

        # Compute the logits.

@ -1,4 +1,4 @@
from typing import List, NamedTuple, Optional, Tuple
from typing import List, Mapping, NamedTuple, Optional, Tuple

import openvino as ov
import torch
@ -12,6 +12,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
                             MultiModalInputs)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata

logger = init_logger(__name__)
@ -23,7 +25,7 @@ class ModelInput(NamedTuple):
    attn_metadata: Optional[OpenVINOAttentionMetadata]
    seq_lens: List[int]
    query_lens: List[int]
    multi_modal_input: Optional[torch.Tensor]
    multi_modal_kwargs: Mapping[str, BatchedTensors]

    @classmethod
    def empty(cls, device):
@ -32,7 +34,7 @@ class ModelInput(NamedTuple):
            attn_metadata=None,
            seq_lens=[],
            query_lens=[],
            multi_modal_input=None)
            multi_modal_kwargs={})


class OpenVINOModelRunner:
@ -78,6 +80,10 @@ class OpenVINOModelRunner:
            self.block_size,
        )

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set after init_Model

@ -108,6 +114,8 @@ class OpenVINOModelRunner:
        seq_lens: List[int] = []
        past_lens: List[int] = []
        query_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []

        subsequence_begins: List[int] = []
        block_indices: List[int] = []
        block_indices_begins: List[int] = []
@ -160,6 +168,11 @@ class OpenVINOModelRunner:
                                 and self.sliding_window is None
                                 and is_prompt)

                mm_data = seq_group_metadata.multi_modal_data
                if mm_data:
                    mm_kwargs = self.multi_modal_input_mapper(mm_data)
                    multi_modal_inputs_list.append(mm_kwargs)

                block_table = seq_group_metadata.block_tables[seq_id]
                # TODO(sang): Combine chunked prefill and prefix caching by
                # only allowing multiple of block_size chunk size.
@ -251,22 +264,24 @@ class OpenVINOModelRunner:
            block_indices_begins=block_indices_begins_tensor,
            max_context_len=max_context_len_tensor,
        )

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
                                                    device=self.device)

        return ModelInput(
            input_tokens,
            input_positions,
            attn_metadata,
            seq_lens,
            query_lens,
            None,
            multi_modal_kwargs=multi_modal_kwargs,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
               SamplingMetadata, Optional[torch.Tensor], ]:
        multi_modal_input = None

               SamplingMetadata, Mapping[str, BatchedTensors]]:
        # Prepare input tensors.
        (
            input_tokens,
@ -274,7 +289,7 @@ class OpenVINOModelRunner:
            attn_metadata,
            seq_lens,
            query_lens,
            multi_modal_input,
            multi_modal_kwargs,
        ) = self._prepare_model_input(seq_group_metadata_list)

        sampling_metadata = SamplingMetadata.prepare(
@ -290,7 +305,7 @@ class OpenVINOModelRunner:
            input_positions,
            attn_metadata,
            sampling_metadata,
            multi_modal_input,
            multi_modal_kwargs,
        )

    @torch.inference_mode()
@ -304,7 +319,7 @@ class OpenVINOModelRunner:
            input_positions,
            attn_metadata,
            sampling_metadata,
            multi_modal_input,
            multi_modal_kwargs,
        ) = self.prepare_input_tensors(seq_group_metadata_list)

        model_executable = self.model
@ -313,9 +328,8 @@ class OpenVINOModelRunner:
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
            **(multi_modal_kwargs or {}),
        }
        if self.vision_language_config:
            execute_model_kwargs.update({"image_input": multi_modal_input})

        hidden_states = model_executable(**execute_model_kwargs)

@ -1,5 +1,5 @@
import time
from typing import List, Optional, Tuple
from typing import List, Mapping, Optional, Tuple

import numpy as np
import torch
@ -12,6 +12,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
                             MultiModalInputs)
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                           SamplerOutput, SequenceGroupMetadata,
                           SequenceOutput)
@ -66,6 +68,10 @@ class TPUModelRunner:
            False,
        )

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

    def load_model(self) -> None:
        self.device = self.device_config.device

@ -193,12 +199,14 @@ class TPUModelRunner:
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ):
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
               Mapping[str, BatchedTensors]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        prompt_lens: List[int] = []
        slot_mapping: List[List[int]] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
@ -224,6 +232,11 @@ class TPUModelRunner:
                slot = block_number * self.block_size + block_offset
                slot_mapping[-1].append(slot)

            mm_data = seq_group_metadata.multi_modal_data
            if mm_data:
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                multi_modal_inputs_list.append(mm_kwargs)

        assert len(prompt_lens) > 0
        num_prefills = len(prompt_lens)
        num_prefill_tokens = sum(prompt_lens)
@ -261,17 +274,24 @@ class TPUModelRunner:
            block_tables=None,
            context_lens=None,
        )
        return input_tokens, input_positions, attn_metadata, prompt_lens

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
                                                    device=self.device)

        return (input_tokens, input_positions, attn_metadata, prompt_lens,
                multi_modal_kwargs)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ):
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
               Mapping[str, BatchedTensors]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []

        batch_idx = 0
        for seq_group_metadata in seq_group_metadata_list:
@ -297,6 +317,11 @@ class TPUModelRunner:
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

            mm_data = seq_group_metadata.multi_modal_data
            if mm_data:
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                multi_modal_inputs_list.append(mm_kwargs)

        batch_size = _get_padded_batch_size(batch_idx)
        num_paddings = batch_size - batch_idx
        input_tokens = input_tokens + [[0]] * num_paddings
@ -330,7 +355,12 @@ class TPUModelRunner:
            block_tables=block_tables,
            context_lens=context_lens,
        )
        return input_tokens, input_positions, attn_metadata, input_lens

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
                                                    device=self.device)

        return (input_tokens, input_positions, attn_metadata, input_lens,
                multi_modal_kwargs)

    def _prepare_sample(
        self,
@ -483,6 +513,7 @@ class ModelWrapper(nn.Module):
        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
        attn_metadata: AttentionMetadata,
        input_lens: torch.Tensor,
        multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]],
        t: torch.Tensor,
        p: torch.Tensor,
        num_samples: int,
@ -496,6 +527,8 @@ class ModelWrapper(nn.Module):
                memory profiling at initialization.
            attn_metadata: The Pallas attention metadata.
            input_lens: The actual input lengths of shape [batch_size].
            multi_modal_kwargs: Keyword arguments from multi-modal data to
                pass to the model.
            t: The sampling temperature of shape [batch_size].
            p: The top-p probability of shape [batch_size].
        """
@ -540,6 +573,7 @@ class ModelWrapper(nn.Module):
            position_ids,
            kv_caches,
            attn_metadata,
            **(multi_modal_kwargs or {}),
        )
        hidden_states = hidden_states.flatten(0, 1)
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

@ -1,5 +1,6 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
                    Type, Union)

import torch
import torch.nn as nn
@ -9,10 +10,13 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
                             MultiModalInputs)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceData,
from vllm.sequence import (IntermediateTensors, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
@ -44,7 +48,7 @@ class ModelInputForXPU(ModelRunnerInputBase):
    input_positions: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    sampling_metadata: Optional["SamplingMetadata"] = None
    multi_modal_input: Optional[Dict[str, torch.Tensor]] = None
    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
@ -116,6 +120,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
            self.block_size,
        )

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set after init_Model

@ -156,12 +164,26 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        model_config = self.model_config
        vlm_config = self.vision_language_config

        if vlm_config:
            max_num_seqs = min(
                max_num_seqs,
                int(max_num_batched_tokens / vlm_config.image_feature_size))

        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))

            seq_data = SequenceData([0] * seq_len)
            dummy_multi_modal_data = None
            seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
                .dummy_data_for_profiling(model_config, seq_len)

            # Having more tokens is over-conservative but otherwise fine
            assert len(seq_data.prompt_token_ids) >= seq_len, (
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but got: {len(seq_data.prompt_token_ids)}")

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
@ -194,7 +216,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForXPU:
        multi_modal_input = None
        multi_modal_kwargs = None
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
            # all decodes.
@ -202,7 +224,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, attn_metadata, seq_lens,
                 multi_modal_input
                 multi_modal_kwargs
                 ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions,
@ -223,6 +245,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
                "multi_modal_kwargs": multi_modal_kwargs,
            }
            metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
@ -232,6 +255,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
@ -244,7 +268,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
            input_positions=input_positions,
            attn_metadata=attn_metadata,
            sampling_metadata=sampling_metadata,
            multi_modal_input=multi_modal_input)
            multi_modal_kwargs=multi_modal_kwargs)

    def _prepare_decode(
        self,
@ -350,10 +374,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
            "positions": model_input.input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": model_input.attn_metadata,
            **(model_input.multi_modal_kwargs or {}),
        }
        if self.vision_language_config:
            execute_model_kwargs.update(
                {"image_input": model_input.multi_modal_input})

        hidden_states = model_executable(**execute_model_kwargs)

@ -376,13 +398,13 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               Optional[torch.Tensor]]:
               Mapping[str, BatchedTensors]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        multi_modal_input_list: List[torch.Tensor] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
@ -403,9 +425,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
            # is always the first token in the sequence.
            input_positions.extend(list(range(computed_len, seq_len)))

            if seq_group_metadata.multi_modal_data:
                multi_modal_input_list.append(
                    seq_group_metadata.multi_modal_data.data)
            mm_data = seq_group_metadata.multi_modal_data
            if mm_data:
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                multi_modal_inputs_list.append(mm_kwargs)

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
@ -435,15 +458,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        num_prompt_tokens = len(input_tokens)

        input_tokens = torch.tensor(input_tokens,
@ -475,5 +489,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
            num_decode_tokens=0,
            block_tables=torch.tensor([], device=self.device, dtype=torch.int),
        )

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
                                                    device=self.device)

        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_input)
                multi_modal_kwargs)