[Core] Support image processor (#4197)

Author: Cyrus Leung, 2024-06-03 13:56:41 +08:00 (committed by GitHub)
Commit: 7a64d24aad (parent: dfbe60dc62)
29 changed files with 1042 additions and 256 deletions


@ -37,6 +37,7 @@ jobs:
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/multimodal --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml


@ -90,6 +90,7 @@ autodoc_mock_imports = [
"sentencepiece",
"vllm.cuda_utils",
"vllm._C",
"PIL",
"numpy",
"tqdm",
"tensorizer",
@ -116,12 +117,13 @@ class MockedClassDocumenter(autodoc.ClassDocumenter):
autodoc.ClassDocumenter = MockedClassDocumenter
intersphinx_mapping = {
'python': ('https://docs.python.org/3', None),
'typing_extensions':
('https://typing-extensions.readthedocs.io/en/latest', None),
'numpy': ('https://numpy.org/doc/stable', None),
'torch': ('https://pytorch.org/docs/stable', None),
'psutil': ('https://psutil.readthedocs.io/en/stable', None),
"python": ("https://docs.python.org/3", None),
"typing_extensions":
("https://typing-extensions.readthedocs.io/en/latest", None),
"pillow": ("https://pillow.readthedocs.io/en/stable", None),
"numpy": ("https://numpy.org/doc/stable", None),
"torch": ("https://pytorch.org/docs/stable", None),
"psutil": ("https://psutil.readthedocs.io/en/stable", None),
}
autodoc_preserve_defaults = True


@ -0,0 +1,51 @@
Multi-Modality
==============
.. currentmodule:: vllm.multimodal
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
:class:`vllm.inputs.PromptStrictInputs` accepts an additional attribute ``multi_modal_data``
which allows you to pass in multi-modal input alongside text and token prompts.
By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model,
you must decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_dummy_data <MultiModalRegistry.register_dummy_data>`,
as well as :meth:`MULTIMODAL_REGISTRY.register_input <MultiModalRegistry.register_input>` for each modality type that the model supports.
.. contents::
:local:
:backlinks: none
Module Contents
+++++++++++++++
.. automodule:: vllm.multimodal
Registry
--------
.. data:: vllm.multimodal.MULTIMODAL_REGISTRY
The global :class:`MultiModalRegistry` which is used by model runners.
.. autoclass:: vllm.multimodal.MultiModalRegistry
:members:
:show-inheritance:
Base Classes
------------
.. autoclass:: vllm.multimodal.MultiModalData
:members:
:show-inheritance:
.. autoclass:: vllm.multimodal.MultiModalPlugin
:members:
:show-inheritance:
Image Classes
-------------
.. automodule:: vllm.multimodal.image
:members:
:show-inheritance:
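
For example, a model class opts in to multi-modal support by stacking these decorators, exactly as the LLaVA implementation in this PR does. A condensed sketch (the class name and body here are placeholders):

.. code-block:: python

    from torch import nn

    from vllm.multimodal import MULTIMODAL_REGISTRY
    from vllm.multimodal.image import get_dummy_image_data

    @MULTIMODAL_REGISTRY.register_image_feature_input()
    @MULTIMODAL_REGISTRY.register_image_pixel_input()
    @MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
    class MyVisionLanguageModel(nn.Module):
        ...  # forward() receives the processed data as keyword arguments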


@ -88,6 +88,7 @@ Documentation
models/adding_model
models/engine_args
models/lora
models/vlm
models/performance
.. toctree::
@ -99,17 +100,18 @@ Documentation
quantization/fp8_e4m3_kvcache
.. toctree::
:maxdepth: 2
:maxdepth: 1
:caption: Developer Documentation
dev/sampling_params
dev/offline_inference/offline_index
dev/engine/engine_index
dev/kernel/paged_attention
dev/multimodal/multimodal_index
dev/dockerfile/dockerfile
.. toctree::
:maxdepth: 2
:maxdepth: 1
:caption: Community
community/meetups


@ -87,6 +87,10 @@ Alongside each architecture, we include some popular models that use it.
- LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi
- :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
- ✅︎
* - :code:`LlavaForConditionalGeneration`
- LLaVA-1.5
- :code:`llava-hf/llava-1.5-7b-hf`\*, :code:`llava-hf/llava-1.5-13b-hf`\*, etc.
-
* - :code:`MiniCPMForCausalLM`
- MiniCPM
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.


@ -0,0 +1,56 @@
.. _vlm:
Using VLMs
==========
This document shows you how to run and serve Vision Language Models (VLMs) using vLLM.
Engine Arguments
----------------
The following :ref:`engine arguments <engine_args>` are specific to VLMs:
.. argparse::
:module: vllm.engine.arg_utils
:func: _vlm_engine_args_parser
:prog: -m vllm.entrypoints.openai.api_server
:nodefaultconst:
Offline Batched Inference
-------------------------
To initialize a VLM, the above engine arguments must be passed to the ``LLM`` class when instantiating the engine.
.. code-block:: python
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
image_input_type="pixel_values",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=576,
)
For now, we only support a single image per text prompt. To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
* ``prompt``: The prompt should have a number of ``<image>`` tokens equal to ``image_feature_size``.
* ``multi_modal_data``: This should be an instance of :class:`~vllm.multimodal.image.ImagePixelData` or :class:`~vllm.multimodal.image.ImageFeatureData`.
.. code-block:: python
prompt = "<image>" * 576 + (
"\nUSER: What is the content of this image?\nASSISTANT:")
# Load the image using PIL.Image
image = ...
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": ImagePixelData(image),
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
A code example can be found in `examples/llava_example.py <https://github.com/vllm-project/vllm/blob/main/examples/llava_example.py>`_.
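
For convenience, the two snippets above can be combined into a single runnable script. This is only a sketch; the image path assumes the test assets referenced by ``examples/llava_example.py``:

.. code-block:: python

    from PIL import Image

    from vllm import LLM
    from vllm.multimodal.image import ImagePixelData

    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
    )

    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")
    image = Image.open("images/stop_sign.jpg")  # placeholder path

    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": ImagePixelData(image),
    })
    for o in outputs:
        print(o.outputs[0].text)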


@ -3,33 +3,36 @@ import os
import subprocess
import torch
from PIL import Image
from vllm import LLM
from vllm.sequence import MultiModalData
from vllm.multimodal.image import ImageFeatureData, ImagePixelData
# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
# You can use `.buildkite/download-images.sh` to download them
def run_llava_pixel_values():
def run_llava_pixel_values(*, disable_image_processor: bool = False):
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
image_input_type="pixel_values",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=576,
disable_image_processor=disable_image_processor,
)
prompt = "<image>" * 576 + (
"\nUSER: What is the content of this image?\nASSISTANT:")
# This should be provided by another online or offline component.
if disable_image_processor:
image = torch.load("images/stop_sign_pixel_values.pt")
else:
image = Image.open("images/stop_sign.jpg")
outputs = llm.generate({
"prompt":
prompt,
"multi_modal_data":
MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
"prompt": prompt,
"multi_modal_data": ImagePixelData(image),
})
for o in outputs:
@ -49,15 +52,13 @@ def run_llava_image_features():
prompt = "<image>" * 576 + (
"\nUSER: What is the content of this image?\nASSISTANT:")
# This should be provided by another online or offline component.
image = torch.load("images/stop_sign_image_features.pt")
image: torch.Tensor = torch.load("images/stop_sign_image_features.pt")
outputs = llm.generate({
"prompt":
prompt,
"multi_modal_data":
MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
"prompt": prompt,
"multi_modal_data": ImageFeatureData(image),
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)


@ -101,6 +101,7 @@ mypy vllm/core --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/multimodal --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml


@ -12,6 +12,7 @@ aiohttp
openai
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
pillow # Required for image processing
prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer


@ -33,8 +33,5 @@ sentence-transformers # required for embedding
# Benchmarking
aiohttp
# Multimodal
pillow
# quantization
bitsandbytes==0.42.0


@ -15,7 +15,9 @@ from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.distributed import destroy_model_parallel
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
from vllm.sequence import MultiModalData, SampleLogprobs
from vllm.multimodal import MultiModalData
from vllm.multimodal.image import ImageFeatureData, ImagePixelData
from vllm.sequence import SampleLogprobs
logger = init_logger(__name__)
@ -24,6 +26,7 @@ _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
# Multi modal related
# You can use `.buildkite/download-images.sh` to download the assets
_PIXEL_VALUES_FILES = [
os.path.join(_TEST_DIR, "images", filename) for filename in
["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"]
@ -89,17 +92,23 @@ def hf_images() -> List[Image.Image]:
@pytest.fixture()
def vllm_images(request) -> "torch.Tensor":
def vllm_images(request) -> List[MultiModalData]:
vision_language_config = request.getfixturevalue("model_and_config")[1]
all_images = []
if vision_language_config.image_input_type == (
VisionLanguageConfig.ImageInputType.IMAGE_FEATURES):
filenames = _IMAGE_FEATURES_FILES
return [
ImageFeatureData(torch.load(filename))
for filename in _IMAGE_FEATURES_FILES
]
else:
filenames = _PIXEL_VALUES_FILES
for filename in filenames:
all_images.append(torch.load(filename))
return torch.concat(all_images, dim=0)
return [
ImagePixelData(Image.open(filename)) for filename in _IMAGE_FILES
]
@pytest.fixture()
def vllm_image_tensors(request) -> List[torch.Tensor]:
return [torch.load(filename) for filename in _PIXEL_VALUES_FILES]
@pytest.fixture()
@ -392,23 +401,17 @@ class VllmRunner:
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[torch.Tensor] = None,
images: Optional[List[MultiModalData]] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
if images is not None:
assert len(prompts) == len(images)
prompt_inputs: List[TextPrompt] = []
for i, prompt in enumerate(prompts):
prompt = TextPrompt(prompt=prompt)
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
if images is not None:
prompt["multi_modal_data"] = MultiModalData(
type=MultiModalData.Type.IMAGE,
data=images[i:i + 1],
)
for i, image in enumerate(images):
inputs[i]["multi_modal_data"] = image
prompt_inputs.append(prompt)
req_outputs = self.model.generate(prompt_inputs,
req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
outputs: List[Tuple[List[List[int]], List[str]]] = []
@ -447,7 +450,7 @@ class VllmRunner:
self,
prompts: List[str],
max_tokens: int,
images: Optional[torch.Tensor] = None,
images: Optional[List[MultiModalData]] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, greedy_params, images=images)


@ -1,7 +1,7 @@
import gc
from dataclasses import fields
from enum import Enum
from typing import Dict, List, Tuple
from typing import Any, Dict, List, Tuple
import pytest
import torch
@ -9,36 +9,50 @@ from transformers import AutoTokenizer
from vllm.config import VisionLanguageConfig
def iter_llava_configs(model_name: str):
image_hw_to_feature_size = {
(336, 336): 576,
}
for (h, w), f in image_hw_to_feature_size.items():
for input_type, input_shape in [
(VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)),
(VisionLanguageConfig.ImageInputType.IMAGE_FEATURES, (1, f, 1024)),
]:
yield (model_name,
VisionLanguageConfig(image_input_type=input_type,
image_feature_size=f,
image_token_id=32000,
image_input_shape=input_shape,
image_processor=model_name,
image_processor_revision=None))
model_and_vl_config = [
("llava-hf/llava-1.5-7b-hf",
VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_feature_size=576,
image_token_id=32000,
image_input_shape=(1, 3, 336, 336))),
("llava-hf/llava-1.5-7b-hf",
VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.IMAGE_FEATURES,
image_feature_size=576,
image_token_id=32000,
image_input_shape=(1, 576, 1024)))
*iter_llava_configs("llava-hf/llava-1.5-7b-hf"),
# Not enough memory
# *iter_llava_configs("llava-hf/llava-1.5-13b-hf"),
]
def as_dict(vision_language_config: VisionLanguageConfig) -> Dict:
def as_dict(vlm_config: VisionLanguageConfig) -> Dict[str, Any]:
"""Flatten vision language config to pure args.
Compatible with what llm entrypoint expects.
"""
result = {}
for field in fields(vision_language_config):
value = getattr(vision_language_config, field.name)
for field in fields(vlm_config):
value = getattr(vlm_config, field.name)
if isinstance(value, Enum):
result[field.name] = value.name.lower()
elif isinstance(value, tuple):
result[field.name] = ",".join([str(item) for item in value])
else:
result[field.name] = value
result["disable_image_processor"] = vlm_config.image_processor is None
return result
@ -67,18 +81,19 @@ def sanitize_vllm_output(vllm_output: Tuple[List[int], str],
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
vllm_image_prompts, vllm_images, model_and_config: tuple,
dtype: str, max_tokens: int, worker_use_ray: bool) -> None:
vllm_image_prompts, vllm_images, model_and_config, dtype: str,
max_tokens: int, worker_use_ray: bool) -> None:
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test are under tests/images.
For huggingface runner, we provide the raw images as input.
For vllm runner, we provide image tensors and corresponding
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalData objects and corresponding
vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
model_id, vision_language_config = model_and_config
hf_model = hf_runner(model_id, dtype=dtype)
hf_outputs = hf_model.generate_greedy(hf_image_prompts,
max_tokens,
@ -88,6 +103,7 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
vllm_model = vllm_runner(model_id,
dtype=dtype,
worker_use_ray=worker_use_ray,
enforce_eager=True,
**as_dict(vision_language_config))
vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
max_tokens,
@ -105,3 +121,7 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
# (Requires multiple GPUs)


@ -0,0 +1,98 @@
import numpy as np
import pytest
from transformers import CLIPImageProcessor
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import ImagePixelData
@pytest.mark.parametrize("dtype", ["half", "bfloat16", "float"])
def test_clip_image_processor(hf_images, dtype):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
IMAGE_HEIGHT = IMAGE_WIDTH = 33
hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
assert isinstance(hf_processor, CLIPImageProcessor)
model_config = ModelConfig(
model=MODEL_NAME,
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype=dtype,
revision=None,
)
vlm_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
)
for image in hf_images:
hf_result = hf_processor.preprocess(
image,
return_tensors="np",
)
vllm_result = MULTIMODAL_REGISTRY.process_input(
ImagePixelData(image),
model_config=model_config,
vlm_config=vlm_config,
)
assert hf_result.keys() == vllm_result.keys()
for key, hf_arr in hf_result.items():
vllm_arr: np.ndarray = vllm_result[key].numpy()
assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}"
assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"
@pytest.mark.parametrize("dtype", ["float"])
def test_image_pixel_types(hf_images, vllm_image_tensors, dtype):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
IMAGE_HEIGHT = IMAGE_WIDTH = 33
model_config = ModelConfig(
model=MODEL_NAME,
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype=dtype,
revision=None,
)
vlm_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
)
for image, tensor in zip(hf_images, vllm_image_tensors):
image_result = MULTIMODAL_REGISTRY.process_input(
ImagePixelData(image),
model_config=model_config,
vlm_config=vlm_config,
)
tensor_result = MULTIMODAL_REGISTRY.process_input(
ImagePixelData(tensor),
model_config=model_config,
vlm_config=vlm_config,
)
assert image_result.keys() == tensor_result.keys()
for key, image_arr in image_result.items():
tensor_arr: np.ndarray = tensor_result[key].numpy()
assert image_arr.shape == tensor_arr.shape, f"Failed for key={key}"
# The examples in PR#3042 have slightly different preprocessing from
# HuggingFace's LlavaProcessor, causing the test to fail.
# assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}"


@ -18,9 +18,10 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest
from vllm.model_executor.utils import set_random_seed
from vllm.multimodal import MultiModalData
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob, MultiModalData
from vllm.sequence import Logprob
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, random_uuid


@ -0,0 +1,20 @@
import pytest
from transformers.image_processing_utils import BaseImageProcessor
from vllm.transformers_utils.image_processor import get_image_processor
IMAGE_PROCESSOR_NAMES = [
"llava-hf/llava-1.5-7b-hf",
"llava-hf/llava-v1.6-34b-hf",
]
@pytest.mark.parametrize("processor_name", IMAGE_PROCESSOR_NAMES)
def test_image_processor_revision(processor_name: str):
# Assume that "main" branch always exists
image_processor = get_image_processor(processor_name, revision="main")
assert isinstance(image_processor, BaseImageProcessor)
# Assume that "never" branch always does not exist
with pytest.raises(OSError, match='not a valid git identifier'):
get_image_processor(processor_name, revision="never")


@ -1094,10 +1094,12 @@ class VisionLanguageConfig:
# worst case scenario (biggest supported resolution).
image_input_shape: tuple
image_feature_size: int
# The image processor to load from HuggingFace
image_processor: Optional[str]
image_processor_revision: Optional[str]
@classmethod
def get_image_input_enum_type(
cls, value: str) -> "VisionLanguageConfig.ImageInputType":
def get_image_input_enum_type(cls, value: str) -> ImageInputType:
"""Get the image input type from a string."""
try:
return cls.ImageInputType[value.upper()]


@ -1,6 +1,7 @@
import argparse
import dataclasses
import json
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@ -80,6 +81,10 @@ class EngineArgs:
image_token_id: Optional[int] = None
image_input_shape: Optional[str] = None
image_feature_size: Optional[int] = None
image_processor: Optional[str] = None
image_processor_revision: Optional[str] = None
disable_image_processor: bool = False
scheduler_delay_factor: float = 0.0
enable_chunked_prefill: bool = False
@ -98,6 +103,53 @@ class EngineArgs:
if self.tokenizer is None:
self.tokenizer = self.model
@staticmethod
def add_cli_args_for_vlm(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument('--image-input-type',
type=nullable_str,
default=None,
choices=[
t.name.lower()
for t in VisionLanguageConfig.ImageInputType
],
help=('The image input type passed into vLLM.'))
parser.add_argument('--image-token-id',
type=int,
default=None,
help=('Input id for image token.'))
parser.add_argument(
'--image-input-shape',
type=nullable_str,
default=None,
help=('The biggest image input shape (worst for memory footprint) '
'given an input type. Only used for vLLM\'s profile_run.'))
parser.add_argument(
'--image-feature-size',
type=int,
default=None,
help=('The image feature size along the context dimension.'))
parser.add_argument(
'--image-processor',
type=str,
default=EngineArgs.image_processor,
help='Name or path of the huggingface image processor to use. '
'If unspecified, model name or path will be used.')
parser.add_argument(
'--image-processor-revision',
type=str,
default=None,
help='Revision of the huggingface image processor version to use. '
'It can be a branch name, a tag name, or a commit id. '
'If unspecified, will use the default version.')
parser.add_argument(
'--disable-image-processor',
action='store_true',
help='Disables the use of image processor, even if one is defined '
'for the model on huggingface.')
return parser
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
@ -113,7 +165,8 @@ class EngineArgs:
'--tokenizer',
type=nullable_str,
default=EngineArgs.tokenizer,
help='Name or path of the huggingface tokenizer to use.')
help='Name or path of the huggingface tokenizer to use. '
'If unspecified, model name or path will be used.')
parser.add_argument(
'--skip-tokenizer-init',
action='store_true',
@ -136,9 +189,9 @@ class EngineArgs:
'--tokenizer-revision',
type=nullable_str,
default=None,
help='The specific tokenizer version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
help='Revision of the huggingface tokenizer to use. '
'It can be a branch name, a tag name, or a commit id. '
'If unspecified, will use the default version.')
parser.add_argument(
'--tokenizer-mode',
type=str,
@ -445,31 +498,10 @@ class EngineArgs:
default=EngineArgs.device,
choices=["auto", "cuda", "neuron", "cpu"],
help='Device type for vLLM execution.')
# Related to Vision-language models such as llava
parser.add_argument(
'--image-input-type',
type=nullable_str,
default=None,
choices=[
t.name.lower() for t in VisionLanguageConfig.ImageInputType
],
help=('The image input type passed into vLLM. '
'Should be one of "pixel_values" or "image_features".'))
parser.add_argument('--image-token-id',
type=int,
default=None,
help=('Input id for image token.'))
parser.add_argument(
'--image-input-shape',
type=nullable_str,
default=None,
help=('The biggest image input shape (worst for memory footprint) '
'given an input type. Only used for vLLM\'s profile_run.'))
parser.add_argument(
'--image-feature-size',
type=int,
default=None,
help=('The image feature size along the context dimension.'))
parser = EngineArgs.add_cli_args_for_vlm(parser)
parser.add_argument(
'--scheduler-delay-factor',
type=float,
@ -488,7 +520,6 @@ class EngineArgs:
default=EngineArgs.speculative_model,
help=
'The name of the draft model to be used in speculative decoding.')
parser.add_argument(
'--num-speculative-tokens',
type=int,
@ -666,12 +697,27 @@ class EngineArgs:
raise ValueError(
'Specify `image_token_id`, `image_input_shape` and '
'`image_feature_size` together with `image_input_type`.')
if self.image_processor is None:
self.image_processor = self.model
if self.disable_image_processor:
if self.image_processor != self.model:
warnings.warn(
"You've specified an image processor "
f"({self.image_processor}) but also disabled "
"it via `--disable-image-processor`.",
stacklevel=2)
self.image_processor = None
vision_language_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.
get_image_input_enum_type(self.image_input_type),
image_token_id=self.image_token_id,
image_input_shape=str_to_int_tuple(self.image_input_shape),
image_feature_size=self.image_feature_size,
image_processor=self.image_processor,
image_processor_revision=self.image_processor_revision,
)
else:
vision_language_config = None
@ -734,3 +780,7 @@ def _engine_args_parser():
def _async_engine_args_parser():
return AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(),
async_args_only=True)
def _vlm_engine_args_parser():
return EngineArgs.add_cli_args_for_vlm(argparse.ArgumentParser())
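
For reference, the new VLM flags map one-to-one onto ``EngineArgs`` fields, so the same options can be passed as keyword arguments to ``LLM`` in the offline API. A sketch only (values mirror ``examples/llava_example.py``; ``image_processor`` is optional and defaults to the model name, as implemented above):

    from vllm import LLM

    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
        # New in this PR; shown here with their effective defaults.
        image_processor="llava-hf/llava-1.5-7b-hf",
        image_processor_revision=None,
        disable_image_processor=False,
    )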


@ -14,7 +14,6 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs
@ -164,7 +163,6 @@ class LLM:
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@ -177,7 +175,6 @@ class LLM:
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@ -191,7 +188,6 @@ class LLM:
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@ -205,7 +201,6 @@ class LLM:
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@ -217,7 +212,6 @@ class LLM:
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@ -236,7 +230,6 @@ class LLM:
@deprecate_kwargs("prompts",
"prompt_token_ids",
"multi_modal_data",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter "
"instead.")
@ -249,7 +242,6 @@ class LLM:
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.
@ -281,11 +273,10 @@ class LLM:
"LLM.generate() is only supported for generation models "
"(XForCausalLM).")
if prompt_token_ids is not None or multi_modal_data is not None:
if prompt_token_ids is not None:
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
)
else:
inputs = cast(
@ -314,7 +305,6 @@ class LLM:
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@ -327,7 +317,6 @@ class LLM:
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@ -341,7 +330,6 @@ class LLM:
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@ -355,7 +343,6 @@ class LLM:
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@ -367,7 +354,6 @@ class LLM:
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@ -386,7 +372,6 @@ class LLM:
@deprecate_kwargs("prompts",
"prompt_token_ids",
"multi_modal_data",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter "
"instead.")
@ -399,7 +384,6 @@ class LLM:
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
"""Generates the completions for the input prompts.
@ -430,11 +414,10 @@ class LLM:
"LLM.encode() is only supported for embedding models (XModel)."
)
if prompt_token_ids is not None or multi_modal_data is not None:
if prompt_token_ids is not None:
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
)
else:
inputs = cast(
@ -459,7 +442,6 @@ class LLM:
self,
prompts: Optional[Union[str, List[str]]],
prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
multi_modal_data: Optional[MultiModalData],
):
# skip_tokenizer_init is now checked in engine
@ -499,9 +481,6 @@ class LLM:
else:
raise AssertionError
if multi_modal_data is not None:
item["multi_modal_data"] = multi_modal_data
inputs.append(item)
return inputs


@ -17,6 +17,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import get_dummy_image_data
from vllm.sequence import SamplerOutput
from .vlm_base import VisionLanguageModelBase
@ -82,6 +84,9 @@ class LlavaImageFeatureInputs(TypedDict):
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
@MULTIMODAL_REGISTRY.register_image_feature_input()
@MULTIMODAL_REGISTRY.register_image_pixel_input()
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
class LlavaForConditionalGeneration(VisionLanguageModelBase):
def __init__(self,
@ -131,30 +136,41 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
return data
def _parse_and_validate_image_input(
self, data: object) -> Optional[LlavaImageInputs]:
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_features = kwargs.pop("image_features", None)
expected_input_type = self.vision_language_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
if data is None:
if expected_input_type == ImageInputType.PIXEL_VALUES:
if image_features is not None:
raise ValueError(
"Expected pixel values but got image features")
if pixel_values is None:
return None
if expected_input_type == ImageInputType.PIXEL_VALUES:
if not isinstance(data, torch.Tensor):
raise TypeError("Image pixel vector should be a tensor, "
f"but received type: {type(data)}")
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values")
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_image_data(data),
data=self._validate_image_data(pixel_values),
)
elif expected_input_type == ImageInputType.IMAGE_FEATURES:
if not isinstance(data, torch.Tensor):
raise TypeError("Image feature vector should be a tensor, "
f"but received type: {type(data)}")
if expected_input_type == ImageInputType.IMAGE_FEATURES:
if pixel_values is not None:
raise ValueError(
"Expected image features but got pixel values")
if image_features is None:
return None
if not isinstance(image_features, torch.Tensor):
raise ValueError("Incorrect type of image features")
return LlavaImageFeatureInputs(
type="image_features",
data=self._validate_image_data(data),
data=self._validate_image_data(image_features),
)
return None
@ -201,12 +217,14 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
return self.multi_modal_projector(image_features)
def forward(self,
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
image_input: Optional[torch.Tensor] = None) -> SamplerOutput:
**kwargs: object,
) -> SamplerOutput:
"""Run forward pass for Llava 1.5.
One key thing to understand is the `input_ids` already accounts for the
@ -239,14 +257,15 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
image_input: A batch of image inputs.
For PIXEL_VALUES, expecting [1, 3, 336, 336].
For IMAGE_FEATURES, expecting [1, 576, 1024].
pixel_values: For PIXEL_VALUES, expects a batch with shape
[1, 3, 336, 336].
image_features: For IMAGE_FEATURES, expects a batch with shape
[1, 576, 1024].
"""
parsed_image_input = self._parse_and_validate_image_input(image_input)
image_input = self._parse_and_validate_image_input(**kwargs)
if parsed_image_input is not None:
vision_embeddings = self._process_image_input(parsed_image_input)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = _merge_vision_embeddings(


@ -0,0 +1,7 @@
from .base import MultiModalData, MultiModalPlugin
from .registry import MULTIMODAL_REGISTRY, MultiModalRegistry
__all__ = [
"MultiModalData", "MultiModalPlugin", "MULTIMODAL_REGISTRY",
"MultiModalRegistry"
]

vllm/multimodal/base.py

@ -0,0 +1,126 @@
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type,
TypeVar)
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger
if TYPE_CHECKING:
import torch
from torch import nn
logger = init_logger(__name__)
class MultiModalData:
"""
Base class that contains multi-modal data.
To add a new modality, add a new file under the ``multimodal`` directory.
In this new file, subclass :class:`~MultiModalData` and
:class:`~MultiModalPlugin`.
Finally, register the new plugin to
:const:`vllm.multimodal.MULTIMODAL_REGISTRY`.
This enables models to call :meth:`MultiModalRegistry.register_input` for
the new modality.
"""
pass
D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type["nn.Module"])
MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
Dict[str, "torch.Tensor"]]
"""Return a dictionary to be passed as keyword arguments to
:meth:`torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers."""
class MultiModalPlugin(ABC, Generic[D]):
"""
Base class that defines data processing logic for a specific modality.
In particular, we adopt a registry pattern to dispatch data processing
according to the model being used (considering that different models may
process the same data differently). This registry is in turn used by
:class:`~MultiModalRegistry` which acts at a higher level
(i.e., the modality of the data).
"""
@classmethod
def get_model_cls(cls, model_config: ModelConfig) -> Type["nn.Module"]:
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
return get_model_architecture(model_config)[0]
def __init__(self) -> None:
self._input_processors: Dict[Type["nn.Module"],
MultiModalInputProcessor[D]] = {}
@abstractmethod
def get_data_type(self) -> Type[D]:
"""
Get the modality (subclass of :class:`~MultiModalData`) served by
this plugin.
"""
raise NotImplementedError
@abstractmethod
def _default_input_processor(
self, data: D, model_config: ModelConfig,
vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
"""Return a dictionary to be passed as keyword arguments to
:meth:`torch.nn.Module.forward`. This is similar in concept to
tokenizers and processors in HuggingFace Transformers.
"""
raise NotImplementedError
def register_input_processor(self,
processor: Optional[
MultiModalInputProcessor[D]] = None):
"""
Register an input processor to a model class.
When the model receives input data that matches the modality served by
this plugin (see :meth:`get_data_type`), the provided input processor is
applied to preprocess the data. If `None` is provided, then the default
input processor is applied instead.
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._input_processors:
logger.warning(
"Model class %s already has an input processor "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._input_processors[model_cls] = processor \
or self._default_input_processor
return model_cls
return wrapper
def process_input(
self, data: D, model_config: ModelConfig,
vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
"""
Apply an input processor to a :class:`~MultiModalData` instance passed
to the model.
The model is identified by ``model_config``. ``vlm_config`` is
for compatibility purposes and may be merged into ``model_config``
in the near future.
"""
model_cls = self.get_model_cls(model_config)
processor = self._input_processors.get(model_cls)
if processor is None:
raise KeyError(f"No input processor in {self} is registered for "
f"model class {model_cls.__name__}.")
return processor(data, model_config, vlm_config)
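
To make the docstrings above concrete, here is a hypothetical plugin for a new modality. Everything named ``Audio*`` is purely illustrative and not part of this PR; only the base-class hooks and ``register_plugin`` come from the code in this commit:

    from typing import Dict, Type

    import torch

    from vllm.config import ModelConfig, VisionLanguageConfig
    from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalData,
                                 MultiModalPlugin)


    class AudioData(MultiModalData):
        """Hypothetical modality carrying a raw waveform tensor."""

        def __init__(self, waveform: torch.Tensor) -> None:
            self.waveform = waveform


    class AudioPlugin(MultiModalPlugin[AudioData]):

        def get_data_type(self) -> Type[AudioData]:
            return AudioData

        def _default_input_processor(
                self, data: AudioData, model_config: ModelConfig,
                vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
            # The returned keys become keyword arguments to the model's forward().
            return {"audio_input": data.waveform.to(model_config.dtype)}


    # Tell the registry about the new modality; models would then opt in via
    # MULTIMODAL_REGISTRY.register_input(AudioData).
    MULTIMODAL_REGISTRY.register_plugin(AudioPlugin())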

vllm/multimodal/image.py

@ -0,0 +1,141 @@
from typing import Dict, Tuple, Type, Union
import torch
from PIL import Image
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger
from vllm.sequence import SequenceData
from vllm.transformers_utils.image_processor import cached_get_image_processor
from .base import MultiModalData, MultiModalPlugin
logger = init_logger(__name__)
def _get_dummy_seq_data(seq_len: int,
vlm_config: VisionLanguageConfig) -> SequenceData:
# NOTE: We assume that the <image> token is repeated `image_feature_size` times
# and then concatenated with the text prompt
# TODO: Enable other ways of inserting the image into the prompt
token_ids = [vlm_config.image_token_id] * vlm_config.image_feature_size
token_ids += [0] * (seq_len - vlm_config.image_feature_size)
return SequenceData(token_ids)
def _get_dummy_values(vlm_config: VisionLanguageConfig) -> torch.Tensor:
if vlm_config.image_processor is None:
values_dtype = torch.float16
else:
values_dtype = torch.uint8
return torch.zeros(vlm_config.image_input_shape, dtype=values_dtype)
def get_dummy_image_data(
seq_len: int,
model_config: ModelConfig,
vlm_config: VisionLanguageConfig,
) -> Tuple[SequenceData, MultiModalData]:
"""Standard dummy data factory for image data (to be used in
:meth:`vllm.multimodal.MultiModalRegistry.register_dummy_data`)."""
seq_data = _get_dummy_seq_data(seq_len, vlm_config)
values = _get_dummy_values(vlm_config)
config_input_type = vlm_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
fake_mm_data: MultiModalData
if config_input_type == ImageInputType.PIXEL_VALUES:
fake_mm_data = ImagePixelData(values)
elif config_input_type == ImageInputType.IMAGE_FEATURES:
fake_mm_data = ImageFeatureData(values)
else:
raise NotImplementedError
return seq_data, fake_mm_data
class ImagePixelData(MultiModalData):
"""
The pixel data of an image. Can be one of:
- :class:`PIL.Image.Image`: An image object. Requires that a HuggingFace
image processor is available to the model.
- :class:`torch.Tensor`: The raw pixel data, which is passed to the model
without additional pre-processing.
"""
def __init__(self, image: Union[Image.Image, torch.Tensor]) -> None:
if isinstance(image, Image.Image):
# So that this class can be created inside the Image context manager
image.load()
self.image = image
class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
def get_data_type(self) -> Type[ImagePixelData]:
return ImagePixelData
def _get_hf_image_processor(self, model_config: ModelConfig,
vlm_config: VisionLanguageConfig):
if vlm_config is None or vlm_config.image_processor is None:
return None
return cached_get_image_processor(
vlm_config.image_processor,
trust_remote_code=model_config.trust_remote_code,
revision=vlm_config.image_processor_revision,
)
def _default_input_processor(
self, data: ImagePixelData, model_config: ModelConfig,
vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
image = data.image
image_processor = self._get_hf_image_processor(model_config,
vlm_config)
if isinstance(image, Image.Image):
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available "
"to process the image object")
try:
return image_processor.preprocess(image, return_tensors="pt") \
.to(model_config.dtype).data
except Exception:
logger.error("Failed to process image (%s)", image)
raise
elif isinstance(image, torch.Tensor):
pixel_values = image.to(model_config.dtype)
return {"pixel_values": pixel_values}
raise TypeError(f"Invalid image type: {type(image)}")
class ImageFeatureData(MultiModalData):
"""
The feature vector of an image, passed directly to the model.
This should be the output of the vision tower.
"""
def __init__(self, image_features: torch.Tensor) -> None:
self.image_features = image_features
class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]):
def get_data_type(self) -> Type[ImageFeatureData]:
return ImageFeatureData
def _default_input_processor(
self, data: ImageFeatureData, model_config: ModelConfig,
vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
image_features = data.image_features.to(model_config.dtype)
return {"image_features": image_features}

vllm/multimodal/registry.py

@ -0,0 +1,156 @@
import functools
from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence,
Tuple, Type, TypeVar)
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger
from .base import MultiModalData, MultiModalPlugin
from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData,
ImagePixelPlugin)
if TYPE_CHECKING:
import torch
from torch import nn
from vllm.sequence import SequenceData
logger = init_logger(__name__)
D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type["nn.Module"])
MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
Dict[str, "torch.Tensor"]]
MultiModalDummyFactory = Callable[[int, ModelConfig, VisionLanguageConfig],
Tuple["SequenceData", MultiModalData]]
class MultiModalRegistry:
"""
This registry is used by model runners to dispatch data processing
according to its modality and the target model.
"""
DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin())
def __init__(self,
*,
plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS
) -> None:
self._plugins_by_data_type = {p.get_data_type(): p for p in plugins}
self._dummy_factories_by_model_type: Dict[Type["nn.Module"],
MultiModalDummyFactory] = {}
def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None:
data_type = plugin.get_data_type()
if data_type in self._plugins_by_data_type:
logger.warning(
"A plugin is already registered for data type %s, "
"and will be overwritten by the new plugin %s.", data_type,
plugin)
self._plugins_by_data_type[data_type] = plugin
def _get_plugin_for_data_type(self, data_type: Type[MultiModalData]):
for typ in data_type.mro():
plugin = self._plugins_by_data_type.get(typ)
if plugin is not None:
return plugin
msg = f"Unknown multi-modal data type: {data_type}"
raise NotImplementedError(msg)
def register_dummy_data(self, factory: MultiModalDummyFactory):
"""
Register a dummy data factory to a model class.
During memory profiling, the provided function is invoked to create
dummy data to be inputted into the model. The modality and shape of
the dummy data should be an upper bound of what the model would receive
at inference time.
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._dummy_factories_by_model_type:
logger.warning(
"Model class %s already has dummy data "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._dummy_factories_by_model_type[model_cls] = factory
return model_cls
return wrapper
def dummy_data_for_profiling(self, seq_len: int, model_config: ModelConfig,
vlm_config: VisionLanguageConfig):
"""Create dummy data for memory profiling."""
model_cls = MultiModalPlugin.get_model_cls(model_config)
dummy_factory = self._dummy_factories_by_model_type.get(model_cls)
if dummy_factory is None:
msg = f"No dummy data defined for model class: {model_cls}"
raise NotImplementedError(msg)
return dummy_factory(seq_len, model_config, vlm_config)
def register_input(
self,
data_type: Type[D],
processor: Optional[MultiModalInputProcessor[D]] = None):
"""
Register an input processor for a specific modality to a model class.
See :meth:`MultiModalPlugin.register_input_processor` for more details.
"""
return self._get_plugin_for_data_type(data_type) \
.register_input_processor(processor)
def register_image_pixel_input(
self,
processor: Optional[
MultiModalInputProcessor[ImagePixelData]] = None):
"""
Register an input processor for image pixel data to a model class.
See :meth:`MultiModalPlugin.register_input_processor` for more details.
"""
return self.register_input(ImagePixelData, processor)
def register_image_feature_input(
self,
processor: Optional[
MultiModalInputProcessor[ImageFeatureData]] = None):
"""
Register an input processor for image feature data to a model class.
See :meth:`MultiModalPlugin.register_input_processor` for more details.
"""
return self.register_input(ImageFeatureData, processor)
def process_input(self, data: MultiModalData, model_config: ModelConfig,
vlm_config: VisionLanguageConfig):
"""
Apply an input processor to a :class:`~MultiModalData` instance passed
to the model.
See :meth:`MultiModalPlugin.process_input` for more details.
"""
return self._get_plugin_for_data_type(type(data)) \
.process_input(data, model_config, vlm_config)
def create_input_processor(self, model_config: ModelConfig,
vlm_config: VisionLanguageConfig):
"""
Create an input processor (see :meth:`process_input`) for a
specific model.
"""
return functools.partial(self.process_input,
model_config=model_config,
vlm_config=vlm_config)
MULTIMODAL_REGISTRY = MultiModalRegistry()
"""The global :class:`~MultiModalRegistry` which is used by model runners."""


@ -5,6 +5,8 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from vllm.block import LogicalTokenBlock
from vllm.inputs import LLMInputs
from vllm.lora.request import LoRARequest
@ -12,8 +14,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
if TYPE_CHECKING:
import torch
from vllm.multimodal import MultiModalData
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
@ -398,25 +399,6 @@ class SequenceGroupState:
generator: Optional = None # type: ignore
class MultiModalData:
"""Multi modal request.
Args:
type: The data type.
data: The actual data.
The required shape and semantic meaning of it depends on the vision
language config of the hosted model.
See `VisionLanguageConfig` in `config.py`.
"""
class Type(enum.Enum):
IMAGE = enum.auto()
def __init__(self, type: Type, data: "torch.Tensor"):
self.type = type
self.data = data
class SequenceGroup:
"""A group of sequences that are generated from the same prompt.
@ -473,7 +455,7 @@ class SequenceGroup:
return next(iter(self.seqs_dict.values())).prompt_token_ids
@property
def multi_modal_data(self) -> Optional[MultiModalData]:
def multi_modal_data(self) -> Optional["MultiModalData"]:
# All sequences in the group should have the same multi-modal data.
# We use the multi-modal data of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).multi_modal_data
@ -655,7 +637,7 @@ class SequenceGroupMetadata:
lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None,
state: Optional[SequenceGroupState] = None,
multi_modal_data: Optional[MultiModalData] = None,
multi_modal_data: Optional["MultiModalData"] = None,
encoder_seq_data: Optional[SequenceData] = None,
cross_block_table: Optional[List[int]] = None,
) -> None:
@ -798,13 +780,13 @@ class SamplerOutput:
outputs: List[CompletionSequenceGroupOutput]
# On-device tensor containing probabilities of each token.
sampled_token_probs: Optional["torch.Tensor"] = None
sampled_token_probs: Optional[torch.Tensor] = None
# On-device tensor containing the logprobs of each token.
logprobs: Optional["torch.Tensor"] = None
# On-device tensor containing the sampled token ids.
sampled_token_ids: Optional["torch.Tensor"] = None
sampled_token_ids: Optional[torch.Tensor] = None
# Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None


@ -0,0 +1,45 @@
from functools import lru_cache
from typing import Optional
from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor
from vllm.logger import init_logger
logger = init_logger(__name__)
def get_image_processor(
processor_name: str,
*args,
trust_remote_code: bool = False,
revision: Optional[str] = None,
**kwargs,
) -> BaseImageProcessor:
"""Gets an image processor for the given model name via HuggingFace."""
try:
processor: BaseImageProcessor = AutoImageProcessor.from_pretrained(
processor_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs)
except ValueError as e:
# If the error pertains to the processor class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
# Unlike AutoTokenizer, AutoImageProcessor does not separate such errors
if not trust_remote_code:
err_msg = (
"Failed to load the image processor. If the image processor is "
"a custom processor not yet available in the HuggingFace "
"transformers library, consider setting "
"`trust_remote_code=True` in LLM or using the "
"`--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
return processor
cached_get_image_processor = lru_cache(get_image_processor)
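
A brief usage sketch, mirroring the revision test added earlier in this diff:

    from vllm.transformers_utils.image_processor import (
        cached_get_image_processor, get_image_processor)

    # Fetch the image processor that ships with LLaVA-1.5 from HuggingFace.
    processor = get_image_processor("llava-hf/llava-1.5-7b-hf", revision="main")

    # The cached variant is what vllm.multimodal.image uses, so repeated calls
    # with the same arguments reuse the same processor instance.
    processor = cached_get_image_processor("llava-hf/llava-1.5-7b-hf")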


@ -1,4 +1,5 @@
from typing import List, Optional, Tuple
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
@ -11,6 +12,7 @@ from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
@ -63,6 +65,16 @@ class CPUModelRunner:
self.block_size,
)
# Create processor for multi-modal data
if self.vision_language_config is not None:
self.multi_modal_input_processor = MULTIMODAL_REGISTRY \
.create_input_processor(
self.model_config,
self.vision_language_config,
)
else:
self.multi_modal_input_processor = None
# Lazy initialization.
self.model: nn.Module # Set after init_Model
@ -80,14 +92,15 @@ class CPUModelRunner:
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
Optional[torch.Tensor]]:
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], Dict[
str, torch.Tensor]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
seq_lens: List[int] = []
multi_modal_input_list: List[torch.Tensor] = []
multi_modal_kwargs_list: Dict[str,
List[torch.Tensor]] = defaultdict(list)
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
@ -108,9 +121,17 @@ class CPUModelRunner:
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, seq_len)))
if seq_group_metadata.multi_modal_data:
multi_modal_input_list.append(
seq_group_metadata.multi_modal_data.data)
mm_data = seq_group_metadata.multi_modal_data
if mm_data is not None:
# Process multi-modal data
if self.multi_modal_input_processor is None:
raise ValueError(
"Multi-modal inputs are only supported by "
"vision language models.")
mm_kwargs = self.multi_modal_input_processor(mm_data)
for k, v in mm_kwargs.items():
multi_modal_kwargs_list[k].append(v)
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
@ -134,14 +155,10 @@ class CPUModelRunner:
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
if multi_modal_input_list:
assert self.vision_language_config, (
"Multi-modal inputs are only supported by "
"vision language models.")
multi_modal_input = torch.cat(multi_modal_input_list,
dim=0).to(self.device)
else:
multi_modal_input = None
multi_modal_kwargs = {
k: torch.cat(v, dim=0).to(self.device)
for k, v in multi_modal_kwargs_list.items()
}
num_prompt_tokens = len(input_tokens)
@ -167,7 +184,7 @@ class CPUModelRunner:
slot_mapping=slot_mapping,
)
return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_input)
multi_modal_kwargs)
def _prepare_decode(
self,
@ -257,8 +274,8 @@ class CPUModelRunner:
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Optional[torch.Tensor]]:
multi_modal_input = None
Optional[Dict[str, torch.Tensor]]]:
multi_modal_kwargs = None
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
@ -266,7 +283,7 @@ class CPUModelRunner:
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_input
multi_modal_kwargs
) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
@ -307,7 +324,7 @@ class CPUModelRunner:
)
return (input_tokens, input_positions, attn_metadata,
sampling_metadata, multi_modal_input)
sampling_metadata, multi_modal_kwargs)
@torch.inference_mode()
def execute_model(


@ -90,7 +90,7 @@ class EmbeddingModelRunner(ModelRunner):
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
Set[LoRARequest], LoRAMapping, torch.Tensor]:
Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]:
if self.is_driver_worker:
assert seq_group_metadata_list is not None
# Prepare input tensors.
@ -102,7 +102,7 @@ class EmbeddingModelRunner(ModelRunner):
_,
lora_mapping,
lora_requests,
multi_modal_input,
multi_modal_kwargs,
slot_mapping,
num_prefill_tokens,
num_decode_tokens,
@ -117,7 +117,7 @@ class EmbeddingModelRunner(ModelRunner):
"input_positions": input_positions,
"lora_requests": lora_requests,
"lora_mapping": lora_mapping,
"multi_modal_input": multi_modal_input,
"multi_modal_kwargs": multi_modal_kwargs,
"num_prefill_tokens": num_prefill_tokens,
"num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping,
@ -132,7 +132,7 @@ class EmbeddingModelRunner(ModelRunner):
input_positions = metadata_dict.pop("input_positions")
lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests")
multi_modal_input = metadata_dict.pop("multi_modal_input")
multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
if metadata_dict:
attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
@ -143,7 +143,7 @@ class EmbeddingModelRunner(ModelRunner):
prompt_lens=None)
return (input_tokens, input_positions, attn_metadata, pooling_metadata,
lora_requests, lora_mapping, multi_modal_input)
lora_requests, lora_mapping, multi_modal_kwargs)
def _prepare_pooling(
self,


@ -1,5 +1,6 @@
import time
import warnings
from collections import defaultdict
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
import numpy as np
@ -18,9 +19,9 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available, make_tensor_with_pad)
@ -44,7 +45,7 @@ class ModelInput(NamedTuple):
query_lens: List[int]
lora_mapping: Optional[LoRAMapping]
lora_requests: Set[LoRARequest]
multi_modal_input: Optional[torch.Tensor]
multi_modal_kwargs: Dict[str, torch.Tensor]
slot_mapping: torch.Tensor
num_prefill_tokens: int
num_decode_tokens: int
@ -60,7 +61,7 @@ class ModelInput(NamedTuple):
query_lens=[],
lora_mapping=None,
lora_requests=set(),
multi_modal_input=None,
multi_modal_kwargs={},
slot_mapping=torch.empty(0, device=device),
num_prefill_tokens=0,
num_decode_tokens=0,
@ -122,6 +123,16 @@ class ModelRunner:
self.block_size,
)
# Create processor for multi-modal data
if self.vision_language_config is not None:
self.multi_modal_input_processor = MULTIMODAL_REGISTRY \
.create_input_processor(
self.model_config,
self.vision_language_config,
)
else:
self.multi_modal_input_processor = None
# Lazy initialization
self.model: nn.Module # Set after load_model
# Set if the backend is flashinfer.
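Note: instead of consuming raw tensors, the runner now asks `MULTIMODAL_REGISTRY.create_input_processor(...)` for a per-model processor that turns incoming multi-modal data into the keyword arguments the model's forward pass expects. A hedged stand-in for that mapping; the class and key names below are illustrative, not the real vLLM types:

```python
from typing import Dict

import torch


class FakeImageData:
    """Stand-in for the multi-modal data object holding a raw image."""

    def __init__(self, image: torch.Tensor) -> None:
        self.image = image


def fake_input_processor(data: FakeImageData) -> Dict[str, torch.Tensor]:
    # A real processor would run the HF image processor (resize, normalize,
    # etc.); this stand-in just batches the tensor under an assumed key name.
    return {"pixel_values": data.image.unsqueeze(0)}


mm_kwargs = fake_input_processor(FakeImageData(torch.zeros(3, 336, 336)))
print({k: tuple(v.shape) for k, v in mm_kwargs.items()})
# {'pixel_values': (1, 3, 336, 336)}
```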
@ -242,7 +253,8 @@ class ModelRunner:
context_lens: List[int] = []
query_lens: List[int] = []
block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = []
multi_modal_kwargs_list: Dict[str,
List[torch.Tensor]] = defaultdict(list)
decode_only = True
num_prefills = 0
num_prefill_tokens = 0
@ -417,9 +429,17 @@ class ModelRunner:
and seq_group_metadata.sampling_params.prompt_logprobs
else 1))
if seq_group_metadata.multi_modal_data:
multi_modal_input_list.append(
seq_group_metadata.multi_modal_data.data)
mm_data = seq_group_metadata.multi_modal_data
if mm_data is not None:
# Process multi-modal data
if self.multi_modal_input_processor is None:
raise ValueError(
"Multi-modal inputs are only supported by "
"vision language models.")
mm_kwargs = self.multi_modal_input_processor(mm_data)
for k, v in mm_kwargs.items():
multi_modal_kwargs_list[k].append(v)
if _is_block_tables_empty(seq_group_metadata.block_tables):
# During memory profiling, the block tables are not
@ -508,16 +528,6 @@ class ModelRunner:
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
if multi_modal_input_list:
assert self.vision_language_config, (
"Multi-modal inputs are only supported by "
"vision language models.")
multi_modal_input = torch.cat(multi_modal_input_list,
dim=0).to(self.device)
else:
multi_modal_input = None
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=self.device)
@ -614,6 +624,11 @@ class ModelRunner:
else:
lora_mapping = None
multi_modal_kwargs = {
k: torch.cat(v, dim=0).to(self.device)
for k, v in multi_modal_kwargs_list.items()
}
return ModelInput(
input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor,
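Note: taken together, the hunks above change how per-sequence multi-modal inputs are batched: each prompt's processed kwargs are queued per key, then concatenated along dim 0 and moved to the device, replacing the old single-tensor concatenation that was removed. A minimal standalone sketch of that flow (tensor shapes and the `pixel_values` key are assumed for illustration):

```python
from collections import defaultdict
from typing import Dict, List

import torch

device = torch.device("cpu")  # stand-in for self.device

# Per-sequence processed outputs, e.g. one image per prompt (shapes assumed).
per_sequence_kwargs: List[Dict[str, torch.Tensor]] = [
    {"pixel_values": torch.zeros(1, 3, 336, 336)},
    {"pixel_values": torch.ones(1, 3, 336, 336)},
]

# Accumulate per key, as _prepare_prompt does with defaultdict(list)...
multi_modal_kwargs_list: Dict[str, List[torch.Tensor]] = defaultdict(list)
for mm_kwargs in per_sequence_kwargs:
    for k, v in mm_kwargs.items():
        multi_modal_kwargs_list[k].append(v)

# ...then batch each key along dim 0 and move it to the target device,
# matching the dict comprehension in the hunk above.
multi_modal_kwargs = {
    k: torch.cat(v, dim=0).to(device)
    for k, v in multi_modal_kwargs_list.items()
}
assert multi_modal_kwargs["pixel_values"].shape == (2, 3, 336, 336)
```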
@ -622,7 +637,7 @@ class ModelRunner:
query_lens=query_lens,
lora_mapping=lora_mapping,
lora_requests=lora_requests,
multi_modal_input=multi_modal_input,
multi_modal_kwargs=multi_modal_kwargs,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
@ -633,7 +648,7 @@ class ModelRunner:
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Set[LoRARequest], LoRAMapping, torch.Tensor]:
Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]:
if self.is_driver_worker:
assert seq_group_metadata_list is not None
# Prepare input tensors.
@ -645,7 +660,7 @@ class ModelRunner:
query_lens,
lora_mapping,
lora_requests,
multi_modal_input,
multi_modal_kwargs,
slot_mapping,
num_prefill_tokens,
num_decode_tokens,
@ -662,7 +677,7 @@ class ModelRunner:
sampling_metadata.selected_token_indices,
"lora_requests": lora_requests,
"lora_mapping": lora_mapping,
"multi_modal_input": multi_modal_input,
"multi_modal_kwargs": multi_modal_kwargs,
"num_prefill_tokens": num_prefill_tokens,
"num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping,
@ -679,7 +694,7 @@ class ModelRunner:
"selected_token_indices")
lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests")
multi_modal_input = metadata_dict.pop("multi_modal_input")
multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
if metadata_dict:
attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
@ -694,7 +709,7 @@ class ModelRunner:
return (input_tokens, input_positions, attn_metadata,
sampling_metadata, lora_requests, lora_mapping,
multi_modal_input)
multi_modal_kwargs)
@torch.inference_mode()
def execute_model(
@ -703,7 +718,7 @@ class ModelRunner:
kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, attn_metadata, sampling_metadata,
lora_requests, lora_mapping, multi_modal_input
lora_requests, lora_mapping, multi_modal_kwargs
) = self.prepare_input_tensors(seq_group_metadata_list)
if self.lora_config:
@ -717,15 +732,14 @@ class ModelRunner:
model_executable = self.graph_runners[graph_batch_size]
else:
model_executable = self.model
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
}
if self.vision_language_config:
execute_model_kwargs.update({"image_input": multi_modal_input})
hidden_states = model_executable(**execute_model_kwargs)
hidden_states = model_executable(
input_ids=input_tokens,
positions=input_positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
**multi_modal_kwargs,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
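Note: dropping the `image_input` special case means the kwargs dict is splatted straight into the model call, so a vision-language model only has to declare the keyword arguments it wants. A toy model showing that calling convention; the `pixel_values` argument and hidden size are assumed for illustration:

```python
from typing import Optional

import torch
import torch.nn as nn


class ToyVLM(nn.Module):
    """Toy forward() that accepts an optional multi-modal keyword argument,
    mirroring how model_executable(..., **multi_modal_kwargs) is invoked."""

    def forward(self,
                input_ids: torch.Tensor,
                pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor:
        hidden = torch.zeros(input_ids.shape[0], 16)
        if pixel_values is not None:
            # A real model would embed and merge the image features here.
            hidden = hidden + pixel_values.mean()
        return hidden


model = ToyVLM()
multi_modal_kwargs = {"pixel_values": torch.ones(2, 3, 8, 8)}
out = model(input_ids=torch.tensor([[1, 2], [3, 4]]), **multi_modal_kwargs)
print(out.shape)  # torch.Size([2, 16])
```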
@ -781,16 +795,24 @@ class ModelRunner:
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
if self.vision_language_config:
model_config = self.model_config
vlm_config = self.vision_language_config
if vlm_config:
max_num_seqs = min(
max_num_seqs,
int(max_num_batched_tokens /
self.vision_language_config.image_feature_size))
int(max_num_batched_tokens / vlm_config.image_feature_size))
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
seq_data, fake_multi_modal_input = _prepare_fake_inputs(
seq_len, self.vision_language_config)
if vlm_config is None:
seq_data = SequenceData([0] * seq_len)
dummy_multi_modal_data = None
else:
seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \
.dummy_data_for_profiling(seq_len, model_config, vlm_config)
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
@ -799,7 +821,7 @@ class ModelRunner:
block_tables=None,
lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None,
multi_modal_data=fake_multi_modal_input,
multi_modal_data=dummy_multi_modal_data,
)
seqs.append(seq)
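Note: the profiling path above caps `max_num_seqs` so that every profiled sequence can carry a full image worth of placeholder tokens, then splits the token budget as evenly as possible across the groups. A small worked example with assumed limits (4096 batched tokens, 256 sequences, 576 image feature tokens):

```python
# Assumed limits for this sketch only.
max_num_batched_tokens = 4096
max_num_seqs = 256
image_feature_size = 576  # number of image placeholder tokens per sequence

# Cap the batch so every profiled sequence can carry one full image.
max_num_seqs = min(max_num_seqs,
                   int(max_num_batched_tokens / image_feature_size))  # -> 7

# Split the token budget as evenly as possible across the groups.
seq_lens = [
    max_num_batched_tokens // max_num_seqs +
    (group_id < max_num_batched_tokens % max_num_seqs)
    for group_id in range(max_num_seqs)
]
assert seq_lens == [586, 585, 585, 585, 585, 585, 585]
assert sum(seq_lens) == max_num_batched_tokens
```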
@ -1034,24 +1056,6 @@ def _get_graph_batch_size(batch_size: int) -> int:
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
def _prepare_fake_inputs(
seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
"""Prepare fake inputs for profile run."""
if vision_language_config:
prompt_tokens = [
vision_language_config.image_token_id
] * vision_language_config.image_feature_size + [0] * (
seq_len - vision_language_config.image_feature_size)
fake_image_input = MultiModalData(
type=MultiModalData.Type.IMAGE,
data=torch.zeros(vision_language_config.image_input_shape,
dtype=torch.float16))
else:
prompt_tokens = [0] * seq_len
fake_image_input = None
return SequenceData(prompt_tokens), fake_image_input
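Note: the removed `_prepare_fake_inputs` logic does not disappear; the same dummy-token/dummy-image recipe is now expected to live behind `MULTIMODAL_REGISTRY.dummy_data_for_profiling`, backed by per-model dummy-data factories. A hedged, standalone sketch of such a factory, using plain lists and tensors in place of the real `SequenceData`/`MultiModalData` types; the token id, feature size, and image shape are illustrative values:

```python
from typing import List, Tuple

import torch


def dummy_image_data(
        seq_len: int, image_token_id: int, image_feature_size: int,
        image_input_shape: Tuple[int, ...]) -> Tuple[List[int], torch.Tensor]:
    """Hypothetical per-model dummy-data factory; plain lists/tensors stand
    in for the real SequenceData / MultiModalData objects."""
    # Same recipe as the removed helper: fill the image placeholder tokens
    # first, then pad the rest of the prompt with zeros.
    prompt_tokens = [image_token_id] * image_feature_size + [0] * (
        seq_len - image_feature_size)
    dummy_image = torch.zeros(image_input_shape, dtype=torch.float16)
    return prompt_tokens, dummy_image


tokens, image = dummy_image_data(seq_len=1024,
                                 image_token_id=32000,
                                 image_feature_size=576,
                                 image_input_shape=(1, 3, 336, 336))
assert len(tokens) == 1024 and image.shape == (1, 3, 336, 336)
```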
def _is_block_tables_empty(block_tables: Union[None, Dict]):
"""
Check if block_tables is None or a dictionary with all None values.