[Core] Support image processor (#4197)
parent dfbe60dc62
commit 7a64d24aad

.github/workflows/mypy.yaml (vendored)
@@ -37,6 +37,7 @@ jobs:
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/multimodal --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
docs/source/conf.py

@@ -90,6 +90,7 @@ autodoc_mock_imports = [
    "sentencepiece",
    "vllm.cuda_utils",
    "vllm._C",
    "PIL",
    "numpy",
    "tqdm",
    "tensorizer",
@@ -116,12 +117,13 @@ class MockedClassDocumenter(autodoc.ClassDocumenter):
autodoc.ClassDocumenter = MockedClassDocumenter

intersphinx_mapping = {
    'python': ('https://docs.python.org/3', None),
    'typing_extensions':
    ('https://typing-extensions.readthedocs.io/en/latest', None),
    'numpy': ('https://numpy.org/doc/stable', None),
    'torch': ('https://pytorch.org/docs/stable', None),
    'psutil': ('https://psutil.readthedocs.io/en/stable', None),
    "python": ("https://docs.python.org/3", None),
    "typing_extensions":
    ("https://typing-extensions.readthedocs.io/en/latest", None),
    "pillow": ("https://pillow.readthedocs.io/en/stable", None),
    "numpy": ("https://numpy.org/doc/stable", None),
    "torch": ("https://pytorch.org/docs/stable", None),
    "psutil": ("https://psutil.readthedocs.io/en/stable", None),
}

autodoc_preserve_defaults = True
docs/source/dev/multimodal/multimodal_index.rst (new file)

@@ -0,0 +1,51 @@
Multi-Modality
==============

.. currentmodule:: vllm.multimodal

vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.

:class:`vllm.inputs.PromptStrictInputs` accepts an additional attribute ``multi_modal_data``
which allows you to pass in multi-modal input alongside text and token prompts.

By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model,
you must decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_dummy_data <MultiModalRegistry.register_dummy_data>`,
as well as :meth:`MULTIMODAL_REGISTRY.register_input <MultiModalRegistry.register_input>` for each modality type to support.

.. contents::
   :local:
   :backlinks: none

Module Contents
+++++++++++++++

.. automodule:: vllm.multimodal

Registry
--------

.. data:: vllm.multimodal.MULTIMODAL_REGISTRY

    The global :class:`MultiModalRegistry` which is used by model runners.

.. autoclass:: vllm.multimodal.MultiModalRegistry
    :members:
    :show-inheritance:

Base Classes
------------

.. autoclass:: vllm.multimodal.MultiModalData
    :members:
    :show-inheritance:

.. autoclass:: vllm.multimodal.MultiModalPlugin
    :members:
    :show-inheritance:

Image Classes
-------------

.. automodule:: vllm.multimodal.image
    :members:
    :show-inheritance:
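For illustration, the registration flow described in the new document above looks roughly like the following. This is a minimal sketch mirroring the decorators this commit applies to `LlavaForConditionalGeneration` later in the diff; `MyVLM` is a placeholder class name, not part of the commit.

from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import get_dummy_image_data


# Accept both image modalities added in this commit, and reuse the standard
# dummy-data factory so the profiler knows the worst-case input shape.
@MULTIMODAL_REGISTRY.register_image_feature_input()
@MULTIMODAL_REGISTRY.register_image_pixel_input()
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
class MyVLM(VisionLanguageModelBase):  # placeholder model class
    ...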
docs/source/index.rst

@@ -88,6 +88,7 @@ Documentation
   models/adding_model
   models/engine_args
   models/lora
   models/vlm
   models/performance

.. toctree::
@@ -99,17 +100,18 @@ Documentation
   quantization/fp8_e4m3_kvcache

.. toctree::
   :maxdepth: 2
   :maxdepth: 1
   :caption: Developer Documentation

   dev/sampling_params
   dev/offline_inference/offline_index
   dev/engine/engine_index
   dev/kernel/paged_attention
   dev/multimodal/multimodal_index
   dev/dockerfile/dockerfile

.. toctree::
   :maxdepth: 2
   :maxdepth: 1
   :caption: Community

   community/meetups
docs/source/models/supported_models.rst

@@ -87,6 +87,10 @@ Alongside each architecture, we include some popular models that use it.
    - LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi
    - :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
    - ✅︎
  * - :code:`LlavaForConditionalGeneration`
    - LLaVA-1.5
    - :code:`llava-hf/llava-1.5-7b-hf`\*, :code:`llava-hf/llava-1.5-13b-hf`\*, etc.
    -
  * - :code:`MiniCPMForCausalLM`
    - MiniCPM
    - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
docs/source/models/vlm.rst (new file)

@@ -0,0 +1,56 @@
.. _vlm:

Using VLMs
==========

This document shows you how to run and serve Vision Language Models (VLMs) using vLLM.

Engine Arguments
----------------

The following :ref:`engine arguments <engine_args>` are specific to VLMs:

.. argparse::
    :module: vllm.engine.arg_utils
    :func: _vlm_engine_args_parser
    :prog: -m vllm.entrypoints.openai.api_server
    :nodefaultconst:

Offline Batched Inference
-------------------------

To initialize a VLM, the aforementioned arguments must be passed to the ``LLM`` class for instantiating the engine.

.. code-block:: python

    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
    )

For now, we only support a single image per text prompt. To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:

* ``prompt``: The prompt should have a number of ``<image>`` tokens equal to ``image_feature_size``.
* ``multi_modal_data``: This should be an instance of :class:`~vllm.multimodal.image.ImagePixelData` or :class:`~vllm.multimodal.image.ImageFeatureData`.

.. code-block:: python

    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    # Load the image using PIL.Image
    image = ...

    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": ImagePixelData(image),
    })

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)

A code example can be found in `examples/llava_example.py <https://github.com/vllm-project/vllm/blob/main/examples/llava_example.py>`_.
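For reference, a self-contained version of the snippet above might look like the following. This is a sketch only: it combines the doc snippet with the example script changed below; the path `images/stop_sign.jpg` refers to the test asset downloaded by `.buildkite/download-images.sh` and is an assumption about your local files.

from PIL import Image

from vllm import LLM
from vllm.multimodal.image import ImagePixelData

llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",
    image_input_type="pixel_values",
    image_token_id=32000,
    image_input_shape="1,3,336,336",
    image_feature_size=576,
)

# One <image> token per visual feature, followed by the text prompt.
prompt = "<image>" * 576 + (
    "\nUSER: What is the content of this image?\nASSISTANT:")

# Assumed local asset; see `.buildkite/download-images.sh` in this commit.
image = Image.open("images/stop_sign.jpg")

outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": ImagePixelData(image),
})

for o in outputs:
    print(o.outputs[0].text)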
examples/llava_example.py

@@ -3,33 +3,36 @@ import os
import subprocess

import torch
from PIL import Image

from vllm import LLM
from vllm.sequence import MultiModalData
from vllm.multimodal.image import ImageFeatureData, ImagePixelData

# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
# You can use `.buildkite/download-images.sh` to download them


def run_llava_pixel_values():
def run_llava_pixel_values(*, disable_image_processor: bool = False):
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
        disable_image_processor=disable_image_processor,
    )

    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    # This should be provided by another online or offline component.
    image = torch.load("images/stop_sign_pixel_values.pt")
    if disable_image_processor:
        image = torch.load("images/stop_sign_pixel_values.pt")
    else:
        image = Image.open("images/stop_sign.jpg")

    outputs = llm.generate({
        "prompt":
        prompt,
        "multi_modal_data":
        MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
        "prompt": prompt,
        "multi_modal_data": ImagePixelData(image),
    })

    for o in outputs:
@@ -49,15 +52,13 @@ def run_llava_image_features():
    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    # This should be provided by another online or offline component.
    image = torch.load("images/stop_sign_image_features.pt")
    image: torch.Tensor = torch.load("images/stop_sign_image_features.pt")

    outputs = llm.generate({
        "prompt":
        prompt,
        "multi_modal_data":
        MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
        "prompt": prompt,
        "multi_modal_data": ImageFeatureData(image),
    })

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)
format.sh

@@ -101,6 +101,7 @@ mypy vllm/core --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/multimodal --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
requirements-common.txt

@@ -12,6 +12,7 @@ aiohttp
openai
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
pillow # Required for image processing
prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
requirements-dev.txt

@@ -33,8 +33,5 @@ sentence-transformers # required for embedding
# Benchmarking
aiohttp

# Multimodal
pillow

# quantization
bitsandbytes==0.42.0
tests/conftest.py

@@ -15,7 +15,9 @@ from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.distributed import destroy_model_parallel
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
from vllm.sequence import MultiModalData, SampleLogprobs
from vllm.multimodal import MultiModalData
from vllm.multimodal.image import ImageFeatureData, ImagePixelData
from vllm.sequence import SampleLogprobs

logger = init_logger(__name__)

@@ -24,6 +26,7 @@ _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]

# Multi modal related
# You can use `.buildkite/download-images.sh` to download the assets
_PIXEL_VALUES_FILES = [
    os.path.join(_TEST_DIR, "images", filename) for filename in
    ["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"]
@@ -89,17 +92,23 @@ def hf_images() -> List[Image.Image]:


@pytest.fixture()
def vllm_images(request) -> "torch.Tensor":
def vllm_images(request) -> List[MultiModalData]:
    vision_language_config = request.getfixturevalue("model_and_config")[1]
    all_images = []
    if vision_language_config.image_input_type == (
            VisionLanguageConfig.ImageInputType.IMAGE_FEATURES):
        filenames = _IMAGE_FEATURES_FILES
        return [
            ImageFeatureData(torch.load(filename))
            for filename in _IMAGE_FEATURES_FILES
        ]
    else:
        filenames = _PIXEL_VALUES_FILES
    for filename in filenames:
        all_images.append(torch.load(filename))
    return torch.concat(all_images, dim=0)
        return [
            ImagePixelData(Image.open(filename)) for filename in _IMAGE_FILES
        ]


@pytest.fixture()
def vllm_image_tensors(request) -> List[torch.Tensor]:
    return [torch.load(filename) for filename in _PIXEL_VALUES_FILES]


@pytest.fixture()
@@ -392,23 +401,17 @@ class VllmRunner:
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[torch.Tensor] = None,
        images: Optional[List[MultiModalData]] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        if images is not None:
            assert len(prompts) == len(images)

        prompt_inputs: List[TextPrompt] = []
        for i, prompt in enumerate(prompts):
            prompt = TextPrompt(prompt=prompt)
            if images is not None:
                prompt["multi_modal_data"] = MultiModalData(
                    type=MultiModalData.Type.IMAGE,
                    data=images[i:i + 1],
                )
        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            for i, image in enumerate(images):
                inputs[i]["multi_modal_data"] = image

            prompt_inputs.append(prompt)

        req_outputs = self.model.generate(prompt_inputs,
        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
@@ -447,7 +450,7 @@ class VllmRunner:
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[torch.Tensor] = None,
        images: Optional[List[MultiModalData]] = None,
    ) -> List[Tuple[List[int], str]]:
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images)
tests/models/test_llava.py

@@ -1,7 +1,7 @@
import gc
from dataclasses import fields
from enum import Enum
from typing import Dict, List, Tuple
from typing import Any, Dict, List, Tuple

import pytest
import torch
@@ -9,36 +9,50 @@ from transformers import AutoTokenizer

from vllm.config import VisionLanguageConfig


def iter_llava_configs(model_name: str):
    image_hw_to_feature_size = {
        (336, 336): 576,
    }

    for (h, w), f in image_hw_to_feature_size.items():
        for input_type, input_shape in [
            (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)),
            (VisionLanguageConfig.ImageInputType.IMAGE_FEATURES, (1, f, 1024)),
        ]:
            yield (model_name,
                   VisionLanguageConfig(image_input_type=input_type,
                                        image_feature_size=f,
                                        image_token_id=32000,
                                        image_input_shape=input_shape,
                                        image_processor=model_name,
                                        image_processor_revision=None))


model_and_vl_config = [
    ("llava-hf/llava-1.5-7b-hf",
     VisionLanguageConfig(
         image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
         image_feature_size=576,
         image_token_id=32000,
         image_input_shape=(1, 3, 336, 336))),
    ("llava-hf/llava-1.5-7b-hf",
     VisionLanguageConfig(
         image_input_type=VisionLanguageConfig.ImageInputType.IMAGE_FEATURES,
         image_feature_size=576,
         image_token_id=32000,
         image_input_shape=(1, 576, 1024)))
    *iter_llava_configs("llava-hf/llava-1.5-7b-hf"),
    # Not enough memory
    # *iter_llava_configs("llava-hf/llava-1.5-13b-hf"),
]


def as_dict(vision_language_config: VisionLanguageConfig) -> Dict:
def as_dict(vlm_config: VisionLanguageConfig) -> Dict[str, Any]:
    """Flatten vision language config to pure args.

    Compatible with what llm entrypoint expects.
    """
    result = {}
    for field in fields(vision_language_config):
        value = getattr(vision_language_config, field.name)
    for field in fields(vlm_config):
        value = getattr(vlm_config, field.name)
        if isinstance(value, Enum):
            result[field.name] = value.name.lower()
        elif isinstance(value, tuple):
            result[field.name] = ",".join([str(item) for item in value])
        else:
            result[field.name] = value

    result["disable_image_processor"] = vlm_config.image_processor is None

    return result


@@ -67,18 +81,19 @@ def sanitize_vllm_output(vllm_output: Tuple[List[int], str],
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
                vllm_image_prompts, vllm_images, model_and_config: tuple,
                dtype: str, max_tokens: int, worker_use_ray: bool) -> None:
                vllm_image_prompts, vllm_images, model_and_config, dtype: str,
                max_tokens: int, worker_use_ray: bool) -> None:
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test is under tests/images.
    For huggingface runner, we provide the raw images as input.
    For vllm runner, we provide image tensors and corresponding
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalData objects and corresponding
    vision language config as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
    model_id, vision_language_config = model_and_config

    hf_model = hf_runner(model_id, dtype=dtype)
    hf_outputs = hf_model.generate_greedy(hf_image_prompts,
                                          max_tokens,
@@ -88,6 +103,7 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
    vllm_model = vllm_runner(model_id,
                             dtype=dtype,
                             worker_use_ray=worker_use_ray,
                             enforce_eager=True,
                             **as_dict(vision_language_config))
    vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
                                              max_tokens,
@@ -105,3 +121,7 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
# (Requires multiple GPUs)
tests/multimodal/__init__.py (new, empty file)

tests/multimodal/test_processor.py (new file)

@@ -0,0 +1,98 @@
import numpy as np
import pytest
from transformers import CLIPImageProcessor

from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import ImagePixelData


@pytest.mark.parametrize("dtype", ["half", "bfloat16", "float"])
def test_clip_image_processor(hf_images, dtype):
    MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
    IMAGE_HEIGHT = IMAGE_WIDTH = 33

    hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
    assert isinstance(hf_processor, CLIPImageProcessor)

    model_config = ModelConfig(
        model=MODEL_NAME,
        tokenizer=MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype=dtype,
        revision=None,
    )
    vlm_config = VisionLanguageConfig(
        image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
        image_token_id=32000,
        image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
        image_feature_size=576,
        image_processor=MODEL_NAME,
        image_processor_revision=None,
    )

    for image in hf_images:
        hf_result = hf_processor.preprocess(
            image,
            return_tensors="np",
        )
        vllm_result = MULTIMODAL_REGISTRY.process_input(
            ImagePixelData(image),
            model_config=model_config,
            vlm_config=vlm_config,
        )

        assert hf_result.keys() == vllm_result.keys()
        for key, hf_arr in hf_result.items():
            vllm_arr: np.ndarray = vllm_result[key].numpy()

            assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}"
            assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"


@pytest.mark.parametrize("dtype", ["float"])
def test_image_pixel_types(hf_images, vllm_image_tensors, dtype):
    MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
    IMAGE_HEIGHT = IMAGE_WIDTH = 33

    model_config = ModelConfig(
        model=MODEL_NAME,
        tokenizer=MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype=dtype,
        revision=None,
    )
    vlm_config = VisionLanguageConfig(
        image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
        image_token_id=32000,
        image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
        image_feature_size=576,
        image_processor=MODEL_NAME,
        image_processor_revision=None,
    )

    for image, tensor in zip(hf_images, vllm_image_tensors):
        image_result = MULTIMODAL_REGISTRY.process_input(
            ImagePixelData(image),
            model_config=model_config,
            vlm_config=vlm_config,
        )
        tensor_result = MULTIMODAL_REGISTRY.process_input(
            ImagePixelData(tensor),
            model_config=model_config,
            vlm_config=vlm_config,
        )

        assert image_result.keys() == tensor_result.keys()
        for key, image_arr in image_result.items():
            tensor_arr: np.ndarray = tensor_result[key].numpy()

            assert image_arr.shape == tensor_arr.shape, f"Failed for key={key}"

            # The examples in PR#3042 have slightly different preprocessing from
            # HuggingFace's LlavaProcessor, causing the test to fail.
            # assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}"
@@ -18,9 +18,10 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest
from vllm.model_executor.utils import set_random_seed
from vllm.multimodal import MultiModalData
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob, MultiModalData
from vllm.sequence import Logprob
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, random_uuid
tests/tokenization/test_image_processor.py (new file)

@@ -0,0 +1,20 @@
import pytest
from transformers.image_processing_utils import BaseImageProcessor

from vllm.transformers_utils.image_processor import get_image_processor

IMAGE_PROCESSOR_NAMES = [
    "llava-hf/llava-1.5-7b-hf",
    "llava-hf/llava-v1.6-34b-hf",
]


@pytest.mark.parametrize("processor_name", IMAGE_PROCESSOR_NAMES)
def test_image_processor_revision(processor_name: str):
    # Assume that "main" branch always exists
    image_processor = get_image_processor(processor_name, revision="main")
    assert isinstance(image_processor, BaseImageProcessor)

    # Assume that "never" branch always does not exist
    with pytest.raises(OSError, match='not a valid git identifier'):
        get_image_processor(processor_name, revision="never")
vllm/config.py

@@ -1094,10 +1094,12 @@ class VisionLanguageConfig:
    # worst case scenario (biggest supported resolution).
    image_input_shape: tuple
    image_feature_size: int
    # The image processor to load from HuggingFace
    image_processor: Optional[str]
    image_processor_revision: Optional[str]

    @classmethod
    def get_image_input_enum_type(
            cls, value: str) -> "VisionLanguageConfig.ImageInputType":
    def get_image_input_enum_type(cls, value: str) -> ImageInputType:
        """Get the image input type from a string."""
        try:
            return cls.ImageInputType[value.upper()]
vllm/engine/arg_utils.py

@@ -1,6 +1,7 @@
import argparse
import dataclasses
import json
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

@@ -80,6 +81,10 @@ class EngineArgs:
    image_token_id: Optional[int] = None
    image_input_shape: Optional[str] = None
    image_feature_size: Optional[int] = None
    image_processor: Optional[str] = None
    image_processor_revision: Optional[str] = None
    disable_image_processor: bool = False

    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False

@@ -98,6 +103,53 @@ class EngineArgs:
        if self.tokenizer is None:
            self.tokenizer = self.model

    @staticmethod
    def add_cli_args_for_vlm(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        parser.add_argument('--image-input-type',
                            type=nullable_str,
                            default=None,
                            choices=[
                                t.name.lower()
                                for t in VisionLanguageConfig.ImageInputType
                            ],
                            help=('The image input type passed into vLLM.'))
        parser.add_argument('--image-token-id',
                            type=int,
                            default=None,
                            help=('Input id for image token.'))
        parser.add_argument(
            '--image-input-shape',
            type=nullable_str,
            default=None,
            help=('The biggest image input shape (worst for memory footprint) '
                  'given an input type. Only used for vLLM\'s profile_run.'))
        parser.add_argument(
            '--image-feature-size',
            type=int,
            default=None,
            help=('The image feature size along the context dimension.'))
        parser.add_argument(
            '--image-processor',
            type=str,
            default=EngineArgs.image_processor,
            help='Name or path of the huggingface image processor to use. '
            'If unspecified, model name or path will be used.')
        parser.add_argument(
            '--image-processor-revision',
            type=str,
            default=None,
            help='Revision of the huggingface image processor version to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
        parser.add_argument(
            '--disable-image-processor',
            action='store_true',
            help='Disables the use of image processor, even if one is defined '
            'for the model on huggingface.')

        return parser

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
@@ -113,7 +165,8 @@ class EngineArgs:
            '--tokenizer',
            type=nullable_str,
            default=EngineArgs.tokenizer,
            help='Name or path of the huggingface tokenizer to use.')
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
@@ -136,9 +189,9 @@ class EngineArgs:
            '--tokenizer-revision',
            type=nullable_str,
            default=None,
            help='The specific tokenizer version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
@@ -445,31 +498,10 @@ class EngineArgs:
            default=EngineArgs.device,
            choices=["auto", "cuda", "neuron", "cpu"],
            help='Device type for vLLM execution.')

        # Related to Vision-language models such as llava
        parser.add_argument(
            '--image-input-type',
            type=nullable_str,
            default=None,
            choices=[
                t.name.lower() for t in VisionLanguageConfig.ImageInputType
            ],
            help=('The image input type passed into vLLM. '
                  'Should be one of "pixel_values" or "image_features".'))
        parser.add_argument('--image-token-id',
                            type=int,
                            default=None,
                            help=('Input id for image token.'))
        parser.add_argument(
            '--image-input-shape',
            type=nullable_str,
            default=None,
            help=('The biggest image input shape (worst for memory footprint) '
                  'given an input type. Only used for vLLM\'s profile_run.'))
        parser.add_argument(
            '--image-feature-size',
            type=int,
            default=None,
            help=('The image feature size along the context dimension.'))
        parser = EngineArgs.add_cli_args_for_vlm(parser)

        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
@@ -488,7 +520,6 @@ class EngineArgs:
            default=EngineArgs.speculative_model,
            help=
            'The name of the draft model to be used in speculative decoding.')

        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
@@ -666,12 +697,27 @@ class EngineArgs:
                raise ValueError(
                    'Specify `image_token_id`, `image_input_shape` and '
                    '`image_feature_size` together with `image_input_type`.')

            if self.image_processor is None:
                self.image_processor = self.model
            if self.disable_image_processor:
                if self.image_processor != self.model:
                    warnings.warn(
                        "You've specified an image processor "
                        f"({self.image_processor}) but also disabled "
                        "it via `--disable-image-processor`.",
                        stacklevel=2)

                self.image_processor = None

            vision_language_config = VisionLanguageConfig(
                image_input_type=VisionLanguageConfig.
                get_image_input_enum_type(self.image_input_type),
                image_token_id=self.image_token_id,
                image_input_shape=str_to_int_tuple(self.image_input_shape),
                image_feature_size=self.image_feature_size,
                image_processor=self.image_processor,
                image_processor_revision=self.image_processor_revision,
            )
        else:
            vision_language_config = None
@@ -734,3 +780,7 @@ def _engine_args_parser():
def _async_engine_args_parser():
    return AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(),
                                        async_args_only=True)


def _vlm_engine_args_parser():
    return EngineArgs.add_cli_args_for_vlm(argparse.ArgumentParser())
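A quick way to exercise the VLM-specific flags wired up by `add_cli_args_for_vlm` above (a minimal sketch; the argument values are illustrative, not defaults):

import argparse

from vllm.engine.arg_utils import EngineArgs

# Build a parser that contains only the VLM-specific arguments added above.
parser = EngineArgs.add_cli_args_for_vlm(argparse.ArgumentParser())
args = parser.parse_args([
    "--image-input-type", "pixel_values",
    "--image-token-id", "32000",
    "--image-input-shape", "1,3,336,336",
    "--image-feature-size", "576",
    "--image-processor", "llava-hf/llava-1.5-7b-hf",
])
print(args.image_processor, args.disable_image_processor)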
vllm/entrypoints/llm.py

@@ -14,7 +14,6 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs

@@ -164,7 +163,6 @@ class LLM:
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:
        ...

@@ -177,7 +175,6 @@ class LLM:
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:
        ...

@@ -191,7 +188,6 @@ class LLM:
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:
        ...

@@ -205,7 +201,6 @@ class LLM:
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:
        ...

@@ -217,7 +212,6 @@ class LLM:
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:
        ...

@@ -236,7 +230,6 @@ class LLM:

    @deprecate_kwargs("prompts",
                      "prompt_token_ids",
                      "multi_modal_data",
                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
                      additional_message="Please use the 'inputs' parameter "
                      "instead.")
@@ -249,7 +242,6 @@ class LLM:
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

@@ -281,11 +273,10 @@ class LLM:
                "LLM.generate() is only supported for generation models "
                "(XForCausalLM).")

        if prompt_token_ids is not None or multi_modal_data is not None:
        if prompt_token_ids is not None:
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
                multi_modal_data=multi_modal_data,
            )
        else:
            inputs = cast(
@@ -314,7 +305,6 @@ class LLM:
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

@@ -327,7 +317,6 @@ class LLM:
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

@@ -341,7 +330,6 @@ class LLM:
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

@@ -355,7 +343,6 @@ class LLM:
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

@@ -367,7 +354,6 @@ class LLM:
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

@@ -386,7 +372,6 @@ class LLM:

    @deprecate_kwargs("prompts",
                      "prompt_token_ids",
                      "multi_modal_data",
                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
                      additional_message="Please use the 'inputs' parameter "
                      "instead.")
@@ -399,7 +384,6 @@ class LLM:
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generates the completions for the input prompts.

@@ -430,11 +414,10 @@ class LLM:
                "LLM.encode() is only supported for embedding models (XModel)."
            )

        if prompt_token_ids is not None or multi_modal_data is not None:
        if prompt_token_ids is not None:
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
                multi_modal_data=multi_modal_data,
            )
        else:
            inputs = cast(
@@ -459,7 +442,6 @@ class LLM:
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
        multi_modal_data: Optional[MultiModalData],
    ):
        # skip_tokenizer_init is now checked in engine

@@ -499,9 +481,6 @@ class LLM:
            else:
                raise AssertionError

            if multi_modal_data is not None:
                item["multi_modal_data"] = multi_modal_data

            inputs.append(item)

        return inputs
vllm/model_executor/models/llava.py

@@ -17,6 +17,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import get_dummy_image_data
from vllm.sequence import SamplerOutput

from .vlm_base import VisionLanguageModelBase
@@ -82,6 +84,9 @@ class LlavaImageFeatureInputs(TypedDict):
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]


@MULTIMODAL_REGISTRY.register_image_feature_input()
@MULTIMODAL_REGISTRY.register_image_pixel_input()
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
class LlavaForConditionalGeneration(VisionLanguageModelBase):

    def __init__(self,
@@ -131,30 +136,41 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
        return data

    def _parse_and_validate_image_input(
            self, data: object) -> Optional[LlavaImageInputs]:
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_features = kwargs.pop("image_features", None)

        expected_input_type = self.vision_language_config.image_input_type
        ImageInputType = VisionLanguageConfig.ImageInputType

        if data is None:
            return None

        if expected_input_type == ImageInputType.PIXEL_VALUES:
            if not isinstance(data, torch.Tensor):
                raise TypeError("Image pixel vector should be a tensor, "
                                f"but received type: {type(data)}")
            if image_features is not None:
                raise ValueError(
                    "Expected pixel values but got image features")
            if pixel_values is None:
                return None

            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values")

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_image_data(data),
                data=self._validate_image_data(pixel_values),
            )
        elif expected_input_type == ImageInputType.IMAGE_FEATURES:
            if not isinstance(data, torch.Tensor):
                raise TypeError("Image feature vector should be a tensor, "
                                f"but received type: {type(data)}")

        if expected_input_type == ImageInputType.IMAGE_FEATURES:
            if pixel_values is not None:
                raise ValueError(
                    "Expected image features but got pixel values")
            if image_features is None:
                return None

            if not isinstance(image_features, torch.Tensor):
                raise ValueError("Incorrect type of image features")

            return LlavaImageFeatureInputs(
                type="image_features",
                data=self._validate_image_data(data),
                data=self._validate_image_data(image_features),
            )

        return None
@@ -201,12 +217,14 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):

        return self.multi_modal_projector(image_features)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                image_input: Optional[torch.Tensor] = None) -> SamplerOutput:
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for Llava 1.5.

        One key thing to understand is the `input_ids` already accounts for the
@@ -227,10 +245,10 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        The model takes two types of image inputs:
        The model takes two types of image inputs:
        PIXEL_VALUES and IMAGE_FEATURES.
        The following shows how each maps to huggingface implementation.
        PIXEL_VALUES:
        PIXEL_VALUES:
        - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353
        IMAGE_FEATURES:
        - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430
@@ -239,14 +257,15 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            image_input: A batch of image inputs.
                For PIXEL_VALUES, expecting [1, 3, 336, 336].
                For IMAGE_FEATURES, expecting [1, 576, 1024].
            pixel_values: For PIXEL_VALUES, expects a batch with shape
                [1, 3, 336, 336].
            image_features: For IMAGE_FEATURES, expects a batch with shape
                [1, 576, 1024].
        """
        parsed_image_input = self._parse_and_validate_image_input(image_input)
        image_input = self._parse_and_validate_image_input(**kwargs)

        if parsed_image_input is not None:
            vision_embeddings = self._process_image_input(parsed_image_input)
        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            inputs_embeds = _merge_vision_embeddings(
vllm/multimodal/__init__.py (new file)

@@ -0,0 +1,7 @@
from .base import MultiModalData, MultiModalPlugin
from .registry import MULTIMODAL_REGISTRY, MultiModalRegistry

__all__ = [
    "MultiModalData", "MultiModalPlugin", "MULTIMODAL_REGISTRY",
    "MultiModalRegistry"
]
vllm/multimodal/base.py (new file)

@@ -0,0 +1,126 @@
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type,
                    TypeVar)

from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger

if TYPE_CHECKING:
    import torch
    from torch import nn

logger = init_logger(__name__)


class MultiModalData:
    """
    Base class that contains multi-modal data.

    To add a new modality, add a new file under ``multimodal`` directory.

    In this new file, subclass :class:`~MultiModalData` and
    :class:`~MultiModalPlugin`.

    Finally, register the new plugin to
    :const:`vllm.multimodal.MULTIMODAL_REGISTRY`.
    This enables models to call :meth:`MultiModalRegistry.register_input` for
    the new modality.
    """
    pass


D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type["nn.Module"])

MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
                                    Dict[str, "torch.Tensor"]]
"""Return a dictionary to be passed as keyword arguments to
:meth:`torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers."""


class MultiModalPlugin(ABC, Generic[D]):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).
    """

    @classmethod
    def get_model_cls(cls, model_config: ModelConfig) -> Type["nn.Module"]:
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        return get_model_architecture(model_config)[0]

    def __init__(self) -> None:
        self._input_processors: Dict[Type["nn.Module"],
                                     MultiModalInputProcessor[D]] = {}

    @abstractmethod
    def get_data_type(self) -> Type[D]:
        """
        Get the modality (subclass of :class:`~MultiModalData`) served by
        this plugin.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_processor(
            self, data: D, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
        """Return a dictionary to be passed as keyword arguments to
        :meth:`torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.
        """
        raise NotImplementedError

    def register_input_processor(self,
                                 processor: Optional[
                                     MultiModalInputProcessor[D]] = None):
        """
        Register an input processor to a model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_type`), the provided input processor is
        applied to preprocess the data. If `None` is provided, then the default
        input processor is applied instead.
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_processors:
                logger.warning(
                    "Model class %s already has an input processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._input_processors[model_cls] = processor \
                or self._default_input_processor

            return model_cls

        return wrapper

    def process_input(
            self, data: D, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
        """
        Apply an input processor to a :class:`~MultiModalData` instance passed
        to the model.

        The model is identified by ``model_config``. ``vlm_config`` is
        for compatibility purposes and may be merged into ``model_config``
        in the near future.
        """
        model_cls = self.get_model_cls(model_config)

        processor = self._input_processors.get(model_cls)
        if processor is None:
            raise KeyError(f"No input processor in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        return processor(data, model_config, vlm_config)
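To illustrate the extension point described in the docstrings above, a hypothetical plugin for a new modality might look like this. This is a sketch only: `AudioData`, `AudioPlugin`, and the waveform handling are invented for illustration and are not part of this commit.

from typing import Dict, Type

import torch

from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalData, MultiModalPlugin


class AudioData(MultiModalData):
    """Hypothetical modality: a raw waveform tensor."""

    def __init__(self, waveform: torch.Tensor) -> None:
        self.waveform = waveform


class AudioPlugin(MultiModalPlugin[AudioData]):

    def get_data_type(self) -> Type[AudioData]:
        return AudioData

    def _default_input_processor(
            self, data: AudioData, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
        # The returned keyword arguments are forwarded to the model's forward().
        return {"audio_input": data.waveform.to(model_config.dtype)}


# Make the new modality available to MULTIMODAL_REGISTRY.register_input().
MULTIMODAL_REGISTRY.register_plugin(AudioPlugin())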
vllm/multimodal/image.py (new file)

@@ -0,0 +1,141 @@
from typing import Dict, Tuple, Type, Union

import torch
from PIL import Image

from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger
from vllm.sequence import SequenceData
from vllm.transformers_utils.image_processor import cached_get_image_processor

from .base import MultiModalData, MultiModalPlugin

logger = init_logger(__name__)


def _get_dummy_seq_data(seq_len: int,
                        vlm_config: VisionLanguageConfig) -> SequenceData:
    # NOTE: We assume that <image> token is repeated `image_feature_size` times
    # and then concatenated with the text prompt
    # TODO: Enable other ways of inserting the image into the prompt

    token_ids = [vlm_config.image_token_id] * vlm_config.image_feature_size
    token_ids += [0] * (seq_len - vlm_config.image_feature_size)

    return SequenceData(token_ids)


def _get_dummy_values(vlm_config: VisionLanguageConfig) -> torch.Tensor:
    if vlm_config.image_processor is None:
        values_dtype = torch.float16
    else:
        values_dtype = torch.uint8

    return torch.zeros(vlm_config.image_input_shape, dtype=values_dtype)


def get_dummy_image_data(
    seq_len: int,
    model_config: ModelConfig,
    vlm_config: VisionLanguageConfig,
) -> Tuple[SequenceData, MultiModalData]:
    """Standard dummy data factory for image data (to be used in
    :meth:`vlm.multimodal.MultiModalRegistry.register_dummy_data`)."""
    seq_data = _get_dummy_seq_data(seq_len, vlm_config)
    values = _get_dummy_values(vlm_config)

    config_input_type = vlm_config.image_input_type
    ImageInputType = VisionLanguageConfig.ImageInputType

    fake_mm_data: MultiModalData
    if config_input_type == ImageInputType.PIXEL_VALUES:
        fake_mm_data = ImagePixelData(values)
    elif config_input_type == ImageInputType.IMAGE_FEATURES:
        fake_mm_data = ImageFeatureData(values)
    else:
        raise NotImplementedError

    return seq_data, fake_mm_data


class ImagePixelData(MultiModalData):
    """
    The pixel data of an image. Can be one of:

    - :class:``PIL.Image``: An image object. Requires that a HuggingFace
      processor is available to the model.
    - :class:``torch.Tensor``: The raw pixel data which is passed to the model
      without additional pre-processing.
    """

    def __init__(self, image: Union[Image.Image, torch.Tensor]) -> None:
        if isinstance(image, Image.Image):
            # So that this class can be created inside the Image context manager
            image.load()

        self.image = image


class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):

    def get_data_type(self) -> Type[ImagePixelData]:
        return ImagePixelData

    def _get_hf_image_processor(self, model_config: ModelConfig,
                                vlm_config: VisionLanguageConfig):
        if vlm_config is None or vlm_config.image_processor is None:
            return None

        return cached_get_image_processor(
            vlm_config.image_processor,
            trust_remote_code=model_config.trust_remote_code,
            revision=vlm_config.image_processor_revision,
        )

    def _default_input_processor(
            self, data: ImagePixelData, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
        image = data.image
        image_processor = self._get_hf_image_processor(model_config,
                                                       vlm_config)

        if isinstance(image, Image.Image):
            if image_processor is None:
                raise RuntimeError("No HuggingFace processor is available"
                                   "to process the image object")
            try:
                return image_processor.preprocess(image, return_tensors="pt") \
                    .to(model_config.dtype).data
            except Exception:
                logger.error("Failed to process image (%s)", image)
                raise
        elif isinstance(image, torch.Tensor):
            pixel_values = image.to(model_config.dtype)

            return {"pixel_values": pixel_values}

        raise TypeError(f"Invalid image type: {type(image)}")


class ImageFeatureData(MultiModalData):
    """
    The feature vector of an image, passed directly to the model.

    This should be the output of the vision tower.
    """

    def __init__(self, image_features: torch.Tensor) -> None:
        self.image_features = image_features


class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]):

    def get_data_type(self) -> Type[ImageFeatureData]:
        return ImageFeatureData

    def _default_input_processor(
            self, data: ImageFeatureData, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
        image_features = data.image_features.to(model_config.dtype)

        return {"image_features": image_features}
156
vllm/multimodal/registry.py
Normal file
156
vllm/multimodal/registry.py
Normal file
@ -0,0 +1,156 @@
|
||||
import functools
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence,
|
||||
Tuple, Type, TypeVar)
|
||||
|
||||
from vllm.config import ModelConfig, VisionLanguageConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .base import MultiModalData, MultiModalPlugin
|
||||
from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData,
|
||||
ImagePixelPlugin)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type["nn.Module"])

MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
                                    Dict[str, "torch.Tensor"]]
MultiModalDummyFactory = Callable[[int, ModelConfig, VisionLanguageConfig],
                                  Tuple["SequenceData", MultiModalData]]


class MultiModalRegistry:
    """
    This registry is used by model runners to dispatch data processing
    according to its modality and the target model.
    """

    DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin())

    def __init__(self,
                 *,
                 plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS
                 ) -> None:
        self._plugins_by_data_type = {p.get_data_type(): p for p in plugins}
        self._dummy_factories_by_model_type: Dict[Type["nn.Module"],
                                                  MultiModalDummyFactory] = {}

    def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None:
        data_type = plugin.get_data_type()

        if data_type in self._plugins_by_data_type:
            logger.warning(
                "A plugin is already registered for data type %s, "
                "and will be overwritten by the new plugin %s.", data_type,
                plugin)

        self._plugins_by_data_type[data_type] = plugin

    def _get_plugin_for_data_type(self, data_type: Type[MultiModalData]):
        for typ in data_type.mro():
            plugin = self._plugins_by_data_type.get(typ)
            if plugin is not None:
                return plugin

        msg = f"Unknown multi-modal data type: {data_type}"
        raise NotImplementedError(msg)

    def register_dummy_data(self, factory: MultiModalDummyFactory):
        """
        Register a dummy data factory to a model class.

        During memory profiling, the provided function is invoked to create
        dummy data to be inputted into the model. The modality and shape of
        the dummy data should be an upper bound of what the model would receive
        at inference time.
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._dummy_factories_by_model_type:
                logger.warning(
                    "Model class %s already has dummy data "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._dummy_factories_by_model_type[model_cls] = factory

            return model_cls

        return wrapper

    def dummy_data_for_profiling(self, seq_len: int, model_config: ModelConfig,
                                 vlm_config: VisionLanguageConfig):
        """Create dummy data for memory profiling."""
        model_cls = MultiModalPlugin.get_model_cls(model_config)
        dummy_factory = self._dummy_factories_by_model_type.get(model_cls)
        if dummy_factory is None:
            msg = f"No dummy data defined for model class: {model_cls}"
            raise NotImplementedError(msg)

        return dummy_factory(seq_len, model_config, vlm_config)

    def register_input(
            self,
            data_type: Type[D],
            processor: Optional[MultiModalInputProcessor[D]] = None):
        """
        Register an input processor for a specific modality to a model class.

        See :meth:`MultiModalPlugin.register_input_processor` for more details.
        """
        return self._get_plugin_for_data_type(data_type) \
            .register_input_processor(processor)

    def register_image_pixel_input(
            self,
            processor: Optional[
                MultiModalInputProcessor[ImagePixelData]] = None):
        """
        Register an input processor for image pixel data to a model class.

        See :meth:`MultiModalPlugin.register_input_processor` for more details.
        """
        return self.register_input(ImagePixelData, processor)

    def register_image_feature_input(
            self,
            processor: Optional[
                MultiModalInputProcessor[ImageFeatureData]] = None):
        """
        Register an input processor for image feature data to a model class.

        See :meth:`MultiModalPlugin.register_input_processor` for more details.
        """
        return self.register_input(ImageFeatureData, processor)

    def process_input(self, data: MultiModalData, model_config: ModelConfig,
                      vlm_config: VisionLanguageConfig):
        """
        Apply an input processor to a :class:`~MultiModalData` instance passed
        to the model.

        See :meth:`MultiModalPlugin.process_input` for more details.
        """
        return self._get_plugin_for_data_type(type(data)) \
            .process_input(data, model_config, vlm_config)

    def create_input_processor(self, model_config: ModelConfig,
                               vlm_config: VisionLanguageConfig):
        """
        Create an input processor (see :meth:`process_input`) for a
        specific model.
        """
        return functools.partial(self.process_input,
                                 model_config=model_config,
                                 vlm_config=vlm_config)


MULTIMODAL_REGISTRY = MultiModalRegistry()
"""The global :class:`~MultiModalRegistry` which is used by model runners."""
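Taken together with the plugin classes above, a model opts into multi-modal support by decorating its class with these registry helpers. A rough sketch of the intended registration flow follows (MyVisionLanguageModel and the dummy shapes are hypothetical, and it assumes ImagePixelData can wrap a raw pixel tensor, as the image plugin allows):

import torch
from torch import nn

from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import ImagePixelData
from vllm.sequence import SequenceData


def _dummy_image_data(seq_len: int, model_config: ModelConfig,
                      vlm_config: VisionLanguageConfig):
    # Upper-bound dummy inputs used only during memory profiling.
    token_ids = [vlm_config.image_token_id] * vlm_config.image_feature_size \
        + [0] * (seq_len - vlm_config.image_feature_size)
    pixels = torch.zeros(vlm_config.image_input_shape, dtype=torch.float16)
    return SequenceData(token_ids), ImagePixelData(pixels)


@MULTIMODAL_REGISTRY.register_dummy_data(_dummy_image_data)
@MULTIMODAL_REGISTRY.register_image_pixel_input()  # no argument: use the plugin's default processor
class MyVisionLanguageModel(nn.Module):  # hypothetical model class
    ...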
vllm/sequence.py
@ -5,6 +5,8 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

from vllm.block import LogicalTokenBlock
from vllm.inputs import LLMInputs
from vllm.lora.request import LoRARequest
@ -12,8 +14,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams

if TYPE_CHECKING:
    import torch

    from vllm.multimodal import MultiModalData
    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics


@ -398,25 +399,6 @@ class SequenceGroupState:
    generator: Optional = None  # type: ignore


class MultiModalData:
    """Multi modal request.

    Args:
        type: The data type.
        data: The actual data.
            The required shape and semantic meaning of it depends on the vision
            language config of the hosted model.
            See `VisionLanguageConfig` in `config.py`.
    """

    class Type(enum.Enum):
        IMAGE = enum.auto()

    def __init__(self, type: Type, data: "torch.Tensor"):
        self.type = type
        self.data = data


class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

@ -473,7 +455,7 @@ class SequenceGroup:
        return next(iter(self.seqs_dict.values())).prompt_token_ids

    @property
    def multi_modal_data(self) -> Optional[MultiModalData]:
    def multi_modal_data(self) -> Optional["MultiModalData"]:
        # All sequences in the group should have the same multi-modal data.
        # We use the multi-modal data of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).multi_modal_data
@ -655,7 +637,7 @@ class SequenceGroupMetadata:
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        state: Optional[SequenceGroupState] = None,
        multi_modal_data: Optional[MultiModalData] = None,
        multi_modal_data: Optional["MultiModalData"] = None,
        encoder_seq_data: Optional[SequenceData] = None,
        cross_block_table: Optional[List[int]] = None,
    ) -> None:
@ -798,13 +780,13 @@ class SamplerOutput:
    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional["torch.Tensor"] = None
    sampled_token_probs: Optional[torch.Tensor] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional["torch.Tensor"] = None
    sampled_token_ids: Optional[torch.Tensor] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
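The MultiModalData class removed above is re-homed in vllm.multimodal, so vllm.sequence only references it as a string forward reference under TYPE_CHECKING. Downstream code that needs the same import-cycle-free pattern could look roughly like this sketch (attach_image is a hypothetical function, not part of the commit):

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Imported only for type checking, mirroring the hunk above; this avoids
    # a circular import between vllm.sequence and vllm.multimodal at runtime.
    from vllm.multimodal import MultiModalData


def attach_image(multi_modal_data: Optional["MultiModalData"] = None) -> None:
    # The string annotation keeps MultiModalData unresolved until a type
    # checker evaluates it, so no runtime import is needed.
    ...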
45
vllm/transformers_utils/image_processor.py
Normal file
@ -0,0 +1,45 @@
from functools import lru_cache
from typing import Optional

from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor

from vllm.logger import init_logger

logger = init_logger(__name__)


def get_image_processor(
    processor_name: str,
    *args,
    trust_remote_code: bool = False,
    revision: Optional[str] = None,
    **kwargs,
) -> BaseImageProcessor:
    """Gets an image processor for the given model name via HuggingFace."""
    try:
        processor: BaseImageProcessor = AutoImageProcessor.from_pretrained(
            processor_name,
            *args,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs)
    except ValueError as e:
        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors
        if not trust_remote_code:
            err_msg = (
                "Failed to load the image processor. If the image processor is "
                "a custom processor not yet available in the HuggingFace "
                "transformers library, consider setting "
                "`trust_remote_code=True` in LLM or using the "
                "`--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        else:
            raise e

    return processor


cached_get_image_processor = lru_cache(get_image_processor)
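For reference, the cached helper is used the same way vLLM's tokenizer getters are; a usage sketch (the checkpoint name is only an example):

from vllm.transformers_utils.image_processor import cached_get_image_processor

# The lru_cache wrapper means repeated calls with the same arguments reuse
# the already-instantiated processor.
image_processor = cached_get_image_processor(
    "llava-hf/llava-1.5-7b-hf",  # example checkpoint; any repo with an image processor config works
    trust_remote_code=False,
)

# A BaseImageProcessor converts PIL images into model-ready tensors, e.g.:
# pixel_values = image_processor(pil_image, return_tensors="pt")["pixel_values"]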
vllm/worker/cpu_model_runner.py
@ -1,4 +1,5 @@
from typing import List, Optional, Tuple
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
@ -11,6 +12,7 @@ from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad

@ -63,6 +65,16 @@ class CPUModelRunner:
            self.block_size,
        )

        # Create processor for multi-modal data
        if self.vision_language_config is not None:
            self.multi_modal_input_processor = MULTIMODAL_REGISTRY \
                .create_input_processor(
                    self.model_config,
                    self.vision_language_config,
                )
        else:
            self.multi_modal_input_processor = None

        # Lazy initialization.
        self.model: nn.Module  # Set after init_Model

@ -80,14 +92,15 @@ class CPUModelRunner:
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               Optional[torch.Tensor]]:
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], Dict[
            str, torch.Tensor]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        multi_modal_input_list: List[torch.Tensor] = []
        multi_modal_kwargs_list: Dict[str,
                                      List[torch.Tensor]] = defaultdict(list)

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
@ -108,9 +121,17 @@ class CPUModelRunner:
            # is always the first token in the sequence.
            input_positions.extend(list(range(computed_len, seq_len)))

            if seq_group_metadata.multi_modal_data:
                multi_modal_input_list.append(
                    seq_group_metadata.multi_modal_data.data)
            mm_data = seq_group_metadata.multi_modal_data
            if mm_data is not None:
                # Process multi-modal data
                if self.multi_modal_input_processor is None:
                    raise ValueError(
                        "Multi-modal inputs are only supported by "
                        "vision language models.")

                mm_kwargs = self.multi_modal_input_processor(mm_data)
                for k, v in mm_kwargs.items():
                    multi_modal_kwargs_list[k].append(v)

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
@ -134,14 +155,10 @@ class CPUModelRunner:
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None
        multi_modal_kwargs = {
            k: torch.cat(v, dim=0).to(self.device)
            for k, v in multi_modal_kwargs_list.items()
        }

        num_prompt_tokens = len(input_tokens)

@ -167,7 +184,7 @@ class CPUModelRunner:
            slot_mapping=slot_mapping,
        )
        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_input)
                multi_modal_kwargs)

    def _prepare_decode(
        self,
@ -257,8 +274,8 @@ class CPUModelRunner:
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Optional[torch.Tensor]]:
        multi_modal_input = None
               Optional[Dict[str, torch.Tensor]]]:
        multi_modal_kwargs = None
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
            # all decodes.
@ -266,7 +283,7 @@ class CPUModelRunner:
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, attn_metadata, seq_lens,
                 multi_modal_input
                 multi_modal_kwargs
                 ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions,
@ -307,7 +324,7 @@ class CPUModelRunner:
            )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, multi_modal_input)
                sampling_metadata, multi_modal_kwargs)

    @torch.inference_mode()
    def execute_model(
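Both model runners now gather the per-sequence processor outputs key by key and concatenate them along the batch dimension. Stripped of the runner plumbing, the reduction looks roughly like this sketch (the pixel_values key and the tensor shapes are illustrative assumptions, not fixed by this commit):

from collections import defaultdict
from typing import Dict, List

import torch

# e.g. two sequences, each contributing a [1, 3, 336, 336] pixel tensor
per_seq_outputs = [
    {"pixel_values": torch.zeros(1, 3, 336, 336)},
    {"pixel_values": torch.zeros(1, 3, 336, 336)},
]

multi_modal_kwargs_list: Dict[str, List[torch.Tensor]] = defaultdict(list)
for mm_kwargs in per_seq_outputs:
    for k, v in mm_kwargs.items():
        multi_modal_kwargs_list[k].append(v)

# Same reduction as in _prepare_prompt: one batched tensor per keyword argument.
multi_modal_kwargs = {
    k: torch.cat(v, dim=0) for k, v in multi_modal_kwargs_list.items()
}
assert multi_modal_kwargs["pixel_values"].shape[0] == 2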
vllm/worker/embedding_model_runner.py
@ -90,7 +90,7 @@ class EmbeddingModelRunner(ModelRunner):
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
               Set[LoRARequest], LoRAMapping, torch.Tensor]:
               Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]:
        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            # Prepare input tensors.
@ -102,7 +102,7 @@ class EmbeddingModelRunner(ModelRunner):
                _,
                lora_mapping,
                lora_requests,
                multi_modal_input,
                multi_modal_kwargs,
                slot_mapping,
                num_prefill_tokens,
                num_decode_tokens,
@ -117,7 +117,7 @@ class EmbeddingModelRunner(ModelRunner):
                "input_positions": input_positions,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
                "multi_modal_input": multi_modal_input,
                "multi_modal_kwargs": multi_modal_kwargs,
                "num_prefill_tokens": num_prefill_tokens,
                "num_decode_tokens": num_decode_tokens,
                "slot_mapping": slot_mapping,
@ -132,7 +132,7 @@ class EmbeddingModelRunner(ModelRunner):
            input_positions = metadata_dict.pop("input_positions")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
            multi_modal_input = metadata_dict.pop("multi_modal_input")
            multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
            if metadata_dict:
                attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)
@ -143,7 +143,7 @@ class EmbeddingModelRunner(ModelRunner):
                prompt_lens=None)

        return (input_tokens, input_positions, attn_metadata, pooling_metadata,
                lora_requests, lora_mapping, multi_modal_input)
                lora_requests, lora_mapping, multi_modal_kwargs)

    def _prepare_pooling(
        self,
vllm/worker/model_runner.py
@ -1,5 +1,6 @@
import time
import warnings
from collections import defaultdict
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union

import numpy as np
@ -18,9 +19,9 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
                           SequenceGroupMetadata)
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
                        is_pin_memory_available, make_tensor_with_pad)

@ -44,7 +45,7 @@ class ModelInput(NamedTuple):
    query_lens: List[int]
    lora_mapping: Optional[LoRAMapping]
    lora_requests: Set[LoRARequest]
    multi_modal_input: Optional[torch.Tensor]
    multi_modal_kwargs: Dict[str, torch.Tensor]
    slot_mapping: torch.Tensor
    num_prefill_tokens: int
    num_decode_tokens: int
@ -60,7 +61,7 @@ class ModelInput(NamedTuple):
            query_lens=[],
            lora_mapping=None,
            lora_requests=set(),
            multi_modal_input=None,
            multi_modal_kwargs={},
            slot_mapping=torch.empty(0, device=device),
            num_prefill_tokens=0,
            num_decode_tokens=0,
@ -122,6 +123,16 @@ class ModelRunner:
            self.block_size,
        )

        # Create processor for multi-modal data
        if self.vision_language_config is not None:
            self.multi_modal_input_processor = MULTIMODAL_REGISTRY \
                .create_input_processor(
                    self.model_config,
                    self.vision_language_config,
                )
        else:
            self.multi_modal_input_processor = None

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
        # Set if the backend is flashinfer.
@ -242,7 +253,8 @@ class ModelRunner:
        context_lens: List[int] = []
        query_lens: List[int] = []
        block_tables: List[List[int]] = []
        multi_modal_input_list: List[torch.Tensor] = []
        multi_modal_kwargs_list: Dict[str,
                                      List[torch.Tensor]] = defaultdict(list)
        decode_only = True
        num_prefills = 0
        num_prefill_tokens = 0
@ -417,9 +429,17 @@ class ModelRunner:
                    and seq_group_metadata.sampling_params.prompt_logprobs
                    else 1))

            if seq_group_metadata.multi_modal_data:
                multi_modal_input_list.append(
                    seq_group_metadata.multi_modal_data.data)
            mm_data = seq_group_metadata.multi_modal_data
            if mm_data is not None:
                # Process multi-modal data
                if self.multi_modal_input_processor is None:
                    raise ValueError(
                        "Multi-modal inputs are only supported by "
                        "vision language models.")

                mm_kwargs = self.multi_modal_input_processor(mm_data)
                for k, v in mm_kwargs.items():
                    multi_modal_kwargs_list[k].append(v)

            if _is_block_tables_empty(seq_group_metadata.block_tables):
                # During memory profiling, the block tables are not
@ -508,16 +528,6 @@ class ModelRunner:
        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=self.device)
@ -614,6 +624,11 @@ class ModelRunner:
        else:
            lora_mapping = None

        multi_modal_kwargs = {
            k: torch.cat(v, dim=0).to(self.device)
            for k, v in multi_modal_kwargs_list.items()
        }

        return ModelInput(
            input_tokens=input_tokens_tensor,
            input_positions=input_positions_tensor,
@ -622,7 +637,7 @@ class ModelRunner:
            query_lens=query_lens,
            lora_mapping=lora_mapping,
            lora_requests=lora_requests,
            multi_modal_input=multi_modal_input,
            multi_modal_kwargs=multi_modal_kwargs,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
@ -633,7 +648,7 @@ class ModelRunner:
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Set[LoRARequest], LoRAMapping, torch.Tensor]:
               Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]:
        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            # Prepare input tensors.
@ -645,7 +660,7 @@ class ModelRunner:
                query_lens,
                lora_mapping,
                lora_requests,
                multi_modal_input,
                multi_modal_kwargs,
                slot_mapping,
                num_prefill_tokens,
                num_decode_tokens,
@ -662,7 +677,7 @@ class ModelRunner:
                sampling_metadata.selected_token_indices,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
                "multi_modal_input": multi_modal_input,
                "multi_modal_kwargs": multi_modal_kwargs,
                "num_prefill_tokens": num_prefill_tokens,
                "num_decode_tokens": num_decode_tokens,
                "slot_mapping": slot_mapping,
@ -679,7 +694,7 @@ class ModelRunner:
                "selected_token_indices")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
            multi_modal_input = metadata_dict.pop("multi_modal_input")
            multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
            if metadata_dict:
                attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)
@ -694,7 +709,7 @@ class ModelRunner:

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, lora_requests, lora_mapping,
                multi_modal_input)
                multi_modal_kwargs)

    @torch.inference_mode()
    def execute_model(
@ -703,7 +718,7 @@ class ModelRunner:
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests, lora_mapping, multi_modal_input
         lora_requests, lora_mapping, multi_modal_kwargs
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        if self.lora_config:
@ -717,15 +732,14 @@ class ModelRunner:
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model
        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        if self.vision_language_config:
            execute_model_kwargs.update({"image_input": multi_modal_input})
        hidden_states = model_executable(**execute_model_kwargs)

        hidden_states = model_executable(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            **multi_modal_kwargs,
        )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)
@ -781,16 +795,24 @@ class ModelRunner:
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        if self.vision_language_config:
        model_config = self.model_config
        vlm_config = self.vision_language_config

        if vlm_config:
            max_num_seqs = min(
                max_num_seqs,
                int(max_num_batched_tokens /
                    self.vision_language_config.image_feature_size))
                int(max_num_batched_tokens / vlm_config.image_feature_size))
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data, fake_multi_modal_input = _prepare_fake_inputs(
                seq_len, self.vision_language_config)

            if vlm_config is None:
                seq_data = SequenceData([0] * seq_len)
                dummy_multi_modal_data = None
            else:
                seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \
                    .dummy_data_for_profiling(seq_len, model_config, vlm_config)

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
@ -799,7 +821,7 @@ class ModelRunner:
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=fake_multi_modal_input,
                multi_modal_data=dummy_multi_modal_data,
            )
            seqs.append(seq)

@ -1034,24 +1056,6 @@ def _get_graph_batch_size(batch_size: int) -> int:
                           _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


def _prepare_fake_inputs(
        seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
    """Prepare fake inputs for profile run."""
    if vision_language_config:
        prompt_tokens = [
            vision_language_config.image_token_id
        ] * vision_language_config.image_feature_size + [0] * (
            seq_len - vision_language_config.image_feature_size)
        fake_image_input = MultiModalData(
            type=MultiModalData.Type.IMAGE,
            data=torch.zeros(vision_language_config.image_input_shape,
                             dtype=torch.float16))
    else:
        prompt_tokens = [0] * seq_len
        fake_image_input = None
    return SequenceData(prompt_tokens), fake_image_input


def _is_block_tables_empty(block_tables: Union[None, Dict]):
    """
    Check if block_tables is None or a dictionary with all None values.
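Since execute_model now splats **multi_modal_kwargs straight into the model call (an empty dict for text-only batches), a vision-language model's forward is expected to accept keyword parameters matching whatever keys its registered input processor emits. A minimal sketch, with a hypothetical class and illustrative parameter names:

from typing import List, Optional

import torch
from torch import nn

from vllm.attention import AttentionMetadata


class MyVisionLanguageModel(nn.Module):  # hypothetical; not part of this commit
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        pixel_values: Optional[torch.Tensor] = None,  # supplied via **multi_modal_kwargs
    ) -> torch.Tensor:
        # When pixel_values is present, embed the image and merge it with the
        # text embeddings before running the language model; otherwise the
        # keyword is simply absent and the model behaves as text-only.
        ...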