[Frontend] Add OpenAI Vision API Support (#5237)

Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>

parent ca3ea51bde
commit 7a9cb294ae
@@ -3,7 +3,7 @@
 Using VLMs
 ==========
 
-This document shows you how to run and serve Vision Language Models (VLMs) using vLLM.
+vLLM provides experimental support for Vision Language Models (VLMs). This document shows you how to run and serve these models using vLLM.
 
 Engine Arguments
 ----------------
@@ -54,3 +54,69 @@ For now, we only support a single image per text prompt. To pass an image to the
         print(generated_text)
 
 A code example can be found in `examples/llava_example.py <https://github.com/vllm-project/vllm/blob/main/examples/llava_example.py>`_.
+
+Online OpenAI Vision API Compatible Inference
+----------------------------------------------
+
+You can serve vision language models with vLLM's HTTP server that is compatible with the `OpenAI Vision API <https://platform.openai.com/docs/guides/vision>`_.
+
+.. note::
+    Currently, vLLM supports only **single** ``image_url`` input per ``messages``. Support for multi-image inputs will be
+    added in the future.
+
+Below is an example of how to launch the same ``llava-hf/llava-1.5-7b-hf`` model with the vLLM API server.
+
+.. important::
+    Since the OpenAI Vision API is based on the `Chat <https://platform.openai.com/docs/api-reference/chat>`_ API, a chat template
+    is **required** to launch the API server if the model's tokenizer does not come with one. In this example, we use the
+    HuggingFace Llava chat template that you can find in the example folder `here <https://github.com/vllm-project/vllm/blob/main/examples/template_llava.jinja>`_.
+
+.. code-block:: bash
+
+    python -m vllm.entrypoints.openai.api_server \
+        --model llava-hf/llava-1.5-7b-hf \
+        --image-input-type pixel_values \
+        --image-token-id 32000 \
+        --image-input-shape 1,3,336,336 \
+        --image-feature-size 576 \
+        --chat-template template_llava.jinja
+
+To consume the server, you can use the OpenAI client as in the example below:
+
+.. code-block:: python
+
+    from openai import OpenAI
+
+    openai_api_key = "EMPTY"
+    openai_api_base = "http://localhost:8000/v1"
+
+    client = OpenAI(
+        api_key=openai_api_key,
+        base_url=openai_api_base,
+    )
+
+    chat_response = client.chat.completions.create(
+        model="llava-hf/llava-1.5-7b-hf",
+        messages=[{
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "What's in this image?"},
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
+                    },
+                },
+            ],
+        }],
+    )
+    print("Chat response:", chat_response)
+
+.. note::
+
+    By default, the timeout for fetching images over HTTP URLs is ``5`` seconds. You can override this by setting the environment variable:
+
+    .. code-block:: shell
+
+        export VLLM_IMAGE_FETCH_TIMEOUT=<timeout>
+
+.. note::
+    Prompt formatting with the image token ``<image>`` is not needed when serving VLMs with the API server, since the prompt is
+    processed automatically by the server.
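Base64-encoded images can also be passed inline as ``data:image/...;base64,...`` URLs; the tests added in this commit exercise that path. Below is a minimal sketch against the server launched above, assuming ``image.jpg`` is a local file (the filename is only illustrative):

.. code-block:: python

    import base64

    from openai import OpenAI

    client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

    # Encode a local image (hypothetical file) as a base64 data URL.
    with open("image.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    chat_response = client.chat.completions.create(
        model="llava-hf/llava-1.5-7b-hf",
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": "What's in this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
                },
            ],
        }],
    )
    print(chat_response.choices[0].message.content)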
@@ -30,6 +30,8 @@ Please see the [OpenAI API Reference](https://platform.openai.com/docs/api-refer
 - Chat: `tools`, and `tool_choice`.
 - Completions: `suffix`.
 
+vLLM also provides experimental support for OpenAI Vision API compatible inference. See more details in [Using VLMs](../models/vlm.rst).
+
 ## Extra Parameters
 vLLM supports a set of parameters that are not part of the OpenAI API.
 In order to use them, you can pass them as extra parameters in the OpenAI client.
@@ -120,4 +122,4 @@ It is the callers responsibility to prompt the model with the tool information,
 
 vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter.
 
 Please refer to the OpenAI API reference documentation for more information.
examples/template_llava.jinja (new file, 23 lines)
@@ -0,0 +1,23 @@
{%- if messages[0]['role'] == 'system' -%}
    {%- set system_message = messages[0]['content'] -%}
    {%- set messages = messages[1:] -%}
{%- else -%}
    {% set system_message = '' -%}
{%- endif -%}

{{ bos_token + system_message }}
{%- for message in messages -%}
    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
    {%- endif -%}

    {%- if message['role'] == 'user' -%}
        {{ 'USER: ' + message['content'] + '\n' }}
    {%- elif message['role'] == 'assistant' -%}
        {{ 'ASSISTANT: ' + message['content'] + eos_token + '\n' }}
    {%- endif -%}
{%- endfor -%}

{%- if add_generation_prompt -%}
    {{ 'ASSISTANT:' }}
{% endif %}
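The server applies this template through the tokenizer's ``apply_chat_template``. Below is a minimal offline sketch of how the template renders a single-turn conversation, assuming ``transformers`` is installed and the file is read from the repo-relative path shown; the expected output string is an assumption:

    from pathlib import Path

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")
    # Attach the chat template added by this commit (path relative to the repo root).
    tokenizer.chat_template = Path("examples/template_llava.jinja").read_text()

    messages = [{"role": "user", "content": "<image>\nWhat's in this image?"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           add_generation_prompt=True,
                                           tokenize=False)
    print(prompt)
    # Expected shape of the output (assumption):
    # "<s>USER: <image>\nWhat's in this image?\nASSISTANT:"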
tests/entrypoints/test_openai_vision.py (new file, 286 lines)
@@ -0,0 +1,286 @@
from pathlib import Path
from typing import Dict

import openai
import pytest
import pytest_asyncio
import ray

from vllm.multimodal.utils import ImageFetchAiohttp, encode_image_base64

from ..utils import ServerRunner

MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
LLAVA_CHAT_TEMPLATE = (Path(__file__).parent.parent.parent /
                       "examples/template_llava.jinja")
assert LLAVA_CHAT_TEMPLATE.exists()
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
    "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
    "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
]

pytestmark = pytest.mark.openai


@pytest.fixture(scope="module")
def server():
    ray.init()
    server_runner = ServerRunner.remote([
        "--model",
        MODEL_NAME,
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "4096",
        "--enforce-eager",
        "--image-input-type",
        "pixel_values",
        "--image-token-id",
        "32000",
        "--image-input-shape",
        "1,3,336,336",
        "--image-feature-size",
        "576",
        "--chat-template",
        str(LLAVA_CHAT_TEMPLATE),
    ])
    ray.get(server_runner.ready.remote())
    yield server_runner
    ray.shutdown()


@pytest.fixture(scope="session")
def client():
    client = openai.AsyncOpenAI(
        base_url="http://localhost:8000/v1",
        api_key="token-abc123",
    )
    yield client


@pytest_asyncio.fixture(scope="session")
async def base64_encoded_image() -> Dict[str, str]:
    return {
        image_url:
        encode_image_base64(await ImageFetchAiohttp.fetch_image(image_url))
        for image_url in TEST_IMAGE_URLS
    }


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_single_chat_session_image(server, client: openai.AsyncOpenAI,
                                         model_name: str, image_url: str):
    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {
                    "url": image_url
                }
            },
            {
                "type": "text",
                "text": "What's in this image?"
            },
        ],
    }]

    # test single completion
    chat_completion = await client.chat.completions.create(model=model_name,
                                                           messages=messages,
                                                           max_tokens=10,
                                                           logprobs=True,
                                                           top_logprobs=5)
    assert len(chat_completion.choices) == 1

    choice = chat_completion.choices[0]
    assert choice.finish_reason == "length"
    assert chat_completion.usage == openai.types.CompletionUsage(
        completion_tokens=10, prompt_tokens=596, total_tokens=606)

    message = choice.message
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 10
    assert message.role == "assistant"
    messages.append({"role": "assistant", "content": message.content})

    # test multi-turn dialogue
    messages.append({"role": "user", "content": "express your result in json"})
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=10,
    )
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_single_chat_session_image_base64encoded(
        server, client: openai.AsyncOpenAI, model_name: str, image_url: str,
        base64_encoded_image: Dict[str, str]):

    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {
                    "url":
                    f"data:image/jpeg;base64,{base64_encoded_image[image_url]}"
                }
            },
            {
                "type": "text",
                "text": "What's in this image?"
            },
        ],
    }]

    # test single completion
    chat_completion = await client.chat.completions.create(model=model_name,
                                                           messages=messages,
                                                           max_tokens=10,
                                                           logprobs=True,
                                                           top_logprobs=5)
    assert len(chat_completion.choices) == 1

    choice = chat_completion.choices[0]
    assert choice.finish_reason == "length"
    assert chat_completion.usage == openai.types.CompletionUsage(
        completion_tokens=10, prompt_tokens=596, total_tokens=606)

    message = choice.message
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 10
    assert message.role == "assistant"
    messages.append({"role": "assistant", "content": message.content})

    # test multi-turn dialogue
    messages.append({"role": "user", "content": "express your result in json"})
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=10,
    )
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_chat_streaming_image(server, client: openai.AsyncOpenAI,
                                    model_name: str, image_url: str):
    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {
                    "url": image_url
                }
            },
            {
                "type": "text",
                "text": "What's in this image?"
            },
        ],
    }]

    # test single completion
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=10,
        temperature=0.0,
    )
    output = chat_completion.choices[0].message.content
    stop_reason = chat_completion.choices[0].finish_reason

    # test streaming
    stream = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=10,
        temperature=0.0,
        stream=True,
    )
    chunks = []
    finish_reason_count = 0
    async for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.role:
            assert delta.role == "assistant"
        if delta.content:
            chunks.append(delta.content)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    # finish reason should only return in last block
    assert finish_reason_count == 1
    assert chunk.choices[0].finish_reason == stop_reason
    assert delta.content
    assert "".join(chunks) == output


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_multi_image_input(server, client: openai.AsyncOpenAI,
                                 model_name: str, image_url: str):

    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {
                    "url": image_url
                }
            },
            {
                "type": "image_url",
                "image_url": {
                    "url": image_url
                }
            },
            {
                "type": "text",
                "text": "What's in this image?"
            },
        ],
    }]

    with pytest.raises(openai.BadRequestError):  # test multi-image input
        await client.chat.completions.create(
            model=model_name,
            messages=messages,
            max_tokens=10,
            temperature=0.0,
        )

    # the server should still work afterwards
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    completion = completion.choices[0].text
    assert completion is not None and len(completion) >= 0


if __name__ == "__main__":
    pytest.main([__file__])
tests/multimodal/test_utils.py (new file, 75 lines)
@@ -0,0 +1,75 @@
import base64
import mimetypes
from tempfile import NamedTemporaryFile
from typing import Dict, Tuple

import numpy as np
import pytest
import pytest_asyncio
from PIL import Image

from vllm.multimodal.utils import ImageFetchAiohttp

# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
    "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
    "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
]


@pytest_asyncio.fixture(scope="session")
async def url_images() -> Dict[str, Image.Image]:
    return {
        image_url: await ImageFetchAiohttp.fetch_image(image_url)
        for image_url in TEST_IMAGE_URLS
    }


def get_supported_suffixes() -> Tuple[str, ...]:
    # We should at least test the file types mentioned in GPT-4 with Vision
    OPENAI_SUPPORTED_SUFFIXES = ('.png', '.jpeg', '.jpg', '.webp', '.gif')

    # Additional file types that are supported by us
    EXTRA_SUPPORTED_SUFFIXES = ('.bmp', '.tiff')

    return OPENAI_SUPPORTED_SUFFIXES + EXTRA_SUPPORTED_SUFFIXES


def _image_equals(a: Image.Image, b: Image.Image) -> bool:
    return (np.asarray(a) == np.asarray(b.convert(a.mode))).all()


@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
                                  image_url: str, suffix: str):
    url_image = url_images[image_url]

    try:
        mime_type = Image.MIME[Image.registered_extensions()[suffix]]
    except KeyError:
        try:
            mime_type = mimetypes.types_map[suffix]
        except KeyError:
            pytest.skip('No MIME type')

    with NamedTemporaryFile(suffix=suffix) as f:
        try:
            url_image.save(f.name)
        except Exception as e:
            if e.args[0] == 'cannot write mode RGBA as JPEG':
                pytest.skip('Conversion not supported')

            raise

        base64_image = base64.b64encode(f.read()).decode("utf-8")
        data_url = f"data:{mime_type};base64,{base64_image}"

        data_image = await ImageFetchAiohttp.fetch_image(data_url)
        if _image_equals(url_image, Image.open(f)):
            assert _image_equals(url_image, data_image)
        else:
            pass  # Lossy format; only check that image can be opened
@@ -5,7 +5,7 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple,
                     Union)
 
 import torch
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, PreTrainedTokenizerBase
 
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

@@ -1119,6 +1119,16 @@ class VisionLanguageConfig:
                 f"Expecting to choose from "
                 f"{[x.name for x in cls.ImageInputType]}.") from e
 
+    #TODO(ywang96): make this a cached property once we refactor the
+    # VisionLanguageConfig class.
+    def get_image_token_text(
+            self, tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]:
+        """Get the image token placeholder text to be inserted into the
+        text prompt and the string representation of the image token id.
+        """
+        image_token_str = tokenizer.decode(self.image_token_id)
+        return image_token_str * self.image_feature_size, image_token_str
+
     def as_cli_args_dict(self) -> Dict[str, Any]:
         """Flatten vision language config to pure args.
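With the LLaVA settings used in the docs above (``--image-token-id 32000``, ``--image-feature-size 576``), ``get_image_token_text`` expands the single image token into 576 repeated placeholders that later replace the image slot in the text prompt. A rough sketch of the intent; treating ``"<image>"`` as the decoded form of token id 32000 is an assumption based on the LLaVA tokenizer:

    # Illustrative only: mirrors VisionLanguageConfig.get_image_token_text.
    image_token_str = "<image>"   # assumed tokenizer.decode(32000) for llava-1.5
    image_feature_size = 576      # from --image-feature-size
    image_token_prompt = image_token_str * image_feature_size
    # The server later combines this with the text prompt, e.g.
    # f"{image_token_prompt}\n{text_prompt}" when model_type == "llava".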
@@ -1,15 +1,16 @@
 import codecs
 import time
-from dataclasses import dataclass
-from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List,
-                    Optional)
+from dataclasses import dataclass, field
+from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable,
+                    List, Optional)
 from typing import Sequence as GenericSequence
 from typing import TypedDict, Union, cast, final
 
 from fastapi import Request
-from openai.types.chat import ChatCompletionContentPartTextParam
+from openai.types.chat import (ChatCompletionContentPartImageParam,
+                               ChatCompletionContentPartTextParam)
 
-from vllm.config import ModelConfig
+from vllm.config import ModelConfig, VisionLanguageConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionContentPartParam, ChatCompletionLogProb,

@@ -21,9 +22,13 @@ from vllm.entrypoints.openai.protocol import (
     FunctionCall, ToolCall, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
+from vllm.inputs import PromptInputs
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
+from vllm.multimodal.image import ImagePixelData
+from vllm.multimodal.utils import (async_get_and_parse_image,
+                                   get_full_image_text_prompt)
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
 from vllm.utils import random_uuid

@@ -40,6 +45,8 @@ class ConversationMessage(TypedDict):
 @dataclass(frozen=True)
 class ChatMessageParseResult:
     messages: List[ConversationMessage]
+    image_futures: List[Awaitable[ImagePixelData]] = field(
+        default_factory=list)
 
 
 class OpenAIServingChat(OpenAIServing):

@@ -94,19 +101,76 @@ class OpenAIServingChat(OpenAIServing):
         parts: Iterable[ChatCompletionContentPartParam],
     ) -> ChatMessageParseResult:
         texts: List[str] = []
+        image_futures: List[Awaitable[ImagePixelData]] = []
 
-        for _, part in enumerate(parts):
+        vlm_config: Optional[VisionLanguageConfig] = getattr(
+            self.engine.engine, "vision_language_config", None)
+        model_config = getattr(self.engine.engine, "model_config", None)
+
+        for part in parts:
             part_type = part["type"]
             if part_type == "text":
                 text = cast(ChatCompletionContentPartTextParam, part)["text"]
+
                 texts.append(text)
+            elif part_type == "image_url":
+                if vlm_config is None:
+                    raise ValueError(
+                        "'image_url' input is not supported as the loaded "
+                        "model is not multimodal.")
+
+                elif len(image_futures) == 0:
+                    assert self.tokenizer is not None
+                    image_url = cast(ChatCompletionContentPartImageParam,
+                                     part)["image_url"]
+
+                    if image_url.get("detail", "auto") != "auto":
+                        logger.warning(
+                            "'image_url.detail' is currently not supported and "
+                            "will be ignored.")
+
+                    image_future = async_get_and_parse_image(image_url["url"])
+                    image_futures.append(image_future)
+
+                else:
+                    raise NotImplementedError(
+                        "Multiple 'image_url' input is currently not supported."
+                    )
+
             else:
                 raise NotImplementedError(f"Unknown part type: {part_type}")
 
-        messages = [ConversationMessage(role=role, content="\n".join(texts))]
+        text_prompt = "\n".join(texts)
 
-        return ChatMessageParseResult(messages=messages)
+        if vlm_config is not None and len(image_futures):
+
+            (image_token_prompt,
+             image_token_str) = vlm_config.get_image_token_text(self.tokenizer)
+
+            # NOTE: If image token string (e.g, <image>) is already present
+            # in the text prompt, we assume it follows the same format required
+            # by the engine.
+            if image_token_str in text_prompt:
+                logger.warning(
+                    "Detected image token string in the text prompt. "
+                    "Skipping prompt formatting.")
+                messages = [
+                    ConversationMessage(role=role, content=text_prompt)
+                ]
+
+            else:
+                full_prompt = get_full_image_text_prompt(
+                    image_prompt=image_token_prompt,
+                    text_prompt=text_prompt,
+                    config=model_config)
+                messages = [
+                    ConversationMessage(role=role, content=full_prompt)
+                ]
+        else:
+            messages = [ConversationMessage(role=role, content=text_prompt)]
+
+        return ChatMessageParseResult(messages=messages,
+                                      image_futures=image_futures)
 
     def _parse_chat_message_content(
         self,

@@ -116,10 +180,10 @@ class OpenAIServingChat(OpenAIServing):
         content = message.get("content")
 
         if content is None:
-            return ChatMessageParseResult(messages=[])
+            return ChatMessageParseResult(messages=[], image_futures=[])
         if isinstance(content, str):
             messages = [ConversationMessage(role=role, content=content)]
-            return ChatMessageParseResult(messages=messages)
+            return ChatMessageParseResult(messages=messages, image_futures=[])
 
         return self._parse_chat_message_content_parts(role, content)
 

@@ -144,11 +208,13 @@ class OpenAIServingChat(OpenAIServing):
 
         try:
             conversation: List[ConversationMessage] = []
+            image_futures: List[Awaitable[ImagePixelData]] = []
 
             for msg in request.messages:
-                parsed_msg = self._parse_chat_message_content(msg)
+                chat_parsed_result = self._parse_chat_message_content(msg)
 
-                conversation.extend(parsed_msg.messages)
+                conversation.extend(chat_parsed_result.messages)
+                image_futures.extend(chat_parsed_result.image_futures)
 
             prompt = self.tokenizer.apply_chat_template(
                 conversation=conversation,

@@ -159,6 +225,17 @@ class OpenAIServingChat(OpenAIServing):
             logger.error("Error in applying chat template from request: %s", e)
             return self.create_error_response(str(e))
 
+        # Fetch image data
+        image_data: Optional[ImagePixelData] = None
+        try:
+            if len(image_futures):
+                # since we support only single image currently
+                assert len(image_futures) == 1
+                image_data = await image_futures[0]
+        except Exception as e:
+            logger.error("Error in loading image data: %s", e)
+            return self.create_error_response(str(e))
+
         request_id = f"cmpl-{random_uuid()}"
         try:
             # Tokenize/detokenize depending on prompt format (string/token list)

@@ -183,11 +260,15 @@ class OpenAIServingChat(OpenAIServing):
         except ValueError as e:
             return self.create_error_response(str(e))
 
+        inputs: PromptInputs = {
+            "prompt": prompt_text,
+            "prompt_token_ids": prompt_ids,
+        }
+        if image_data is not None:
+            inputs["multi_modal_data"] = image_data
+
         result_generator = self.engine.generate(
-            {
-                "prompt": prompt_text,
-                "prompt_token_ids": prompt_ids
-            },
+            inputs,
             sampling_params,
             request_id,
             lora_request,
@@ -29,6 +29,7 @@ if TYPE_CHECKING:
     VLLM_CPU_KVCACHE_SPACE: int = 0
     VLLM_USE_RAY_COMPILED_DAG: bool = False
     VLLM_WORKER_MULTIPROC_METHOD: str = "spawn"
+    VLLM_IMAGE_FETCH_TIMEOUT: int = 5
     VLLM_TARGET_DEVICE: str = "cuda"
     MAX_JOBS: Optional[str] = None
     NVCC_THREADS: Optional[str] = None

@@ -216,6 +217,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     # Both spawn and fork work
     "VLLM_WORKER_MULTIPROC_METHOD":
     lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"),
+
+    # Timeout for fetching images when serving multimodal models
+    # Default is 5 seconds
+    "VLLM_IMAGE_FETCH_TIMEOUT":
+    lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),
 }
 
 # end-env-vars-definition
vllm/multimodal/utils.py (new file, 85 lines)
@@ -0,0 +1,85 @@
import base64
from io import BytesIO
from typing import Optional, Union

import aiohttp
from PIL import Image

from vllm.config import ModelConfig
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
from vllm.multimodal.image import ImagePixelData


class ImageFetchAiohttp:
    aiohttp_client: Optional[aiohttp.ClientSession] = None

    @classmethod
    def get_aiohttp_client(cls) -> aiohttp.ClientSession:
        if cls.aiohttp_client is None:
            timeout = aiohttp.ClientTimeout(total=VLLM_IMAGE_FETCH_TIMEOUT)
            connector = aiohttp.TCPConnector()
            cls.aiohttp_client = aiohttp.ClientSession(timeout=timeout,
                                                       connector=connector)

        return cls.aiohttp_client

    @classmethod
    async def fetch_image(cls, image_url: str) -> Image.Image:
        """Load PIL image from a url or base64 encoded openai GPT4V format"""

        if image_url.startswith('http'):
            # Avoid circular import
            from vllm import __version__ as VLLM_VERSION

            client = cls.get_aiohttp_client()
            headers = {"User-Agent": f"vLLM/{VLLM_VERSION}"}

            async with client.get(url=image_url, headers=headers) as response:
                response.raise_for_status()
                image_raw = await response.read()
            image = Image.open(BytesIO(image_raw))

        # Only split once and assume the second part is the base64 encoded image
        elif image_url.startswith('data:image'):
            image = load_image_from_base64(image_url.split(',', 1)[1])

        else:
            raise ValueError("Invalid image url: A valid image url must start "
                             "with either 'data:image' or 'http'.")

        return image


async def async_get_and_parse_image(image_url: str) -> ImagePixelData:
    with await ImageFetchAiohttp.fetch_image(image_url) as image:
        return ImagePixelData(image)


def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str:
    """encode image to base64 format."""

    buffered = BytesIO()
    if format == 'JPEG':
        image = image.convert('RGB')
    image.save(buffered, format)
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
    """Load image from base64 format."""
    return Image.open(BytesIO(base64.b64decode(image)))


# TODO(ywang96): move this to a model registry for preprocessing vision
# language prompts based on the model type.
def get_full_image_text_prompt(image_prompt: str, text_prompt: str,
                               config: ModelConfig) -> str:
    """Combine image and text prompts for vision language model depending on
    the model architecture."""

    if config.hf_config.model_type == "llava":
        full_prompt = f"{image_prompt}\n{text_prompt}"
    else:
        raise ValueError(
            f"Unsupported model type: {config.hf_config.model_type}")
    return full_prompt
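A quick round-trip sketch of the helpers above, assuming network access to the Wikimedia URL used elsewhere in this commit:

    import asyncio

    from vllm.multimodal.utils import (ImageFetchAiohttp, encode_image_base64,
                                       load_image_from_base64)

    async def main():
        url = ("https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/"
               "Gfp-wisconsin-madison-the-nature-boardwalk.jpg/"
               "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg")
        image = await ImageFetchAiohttp.fetch_image(url)  # HTTP fetch
        b64 = encode_image_base64(image)                  # PIL -> base64 (JPEG)
        restored = load_image_from_base64(b64)            # base64 -> PIL
        # A data URL built from `b64` is what the OpenAI-compatible server accepts:
        data_image = await ImageFetchAiohttp.fetch_image(
            f"data:image/jpeg;base64,{b64}")
        print(restored.size, data_image.size)

    asyncio.run(main())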