[Frontend] Add OpenAI Vision API Support (#5237)

Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Roger Wang 2024-06-07 11:23:32 -07:00 committed by GitHub
parent ca3ea51bde
commit 7a9cb294ae
9 changed files with 653 additions and 19 deletions

View File

@@ -3,7 +3,7 @@
Using VLMs
==========
This document shows you how to run and serve Vision Language Models (VLMs) using vLLM.
vLLM provides experimental support for Vision Language Models (VLMs). This document shows you how to run and serve these models using vLLM.
Engine Arguments
----------------
@@ -54,3 +54,69 @@ For now, we only support a single image per text prompt. To pass an image to the
print(generated_text)
A code example can be found in `examples/llava_example.py <https://github.com/vllm-project/vllm/blob/main/examples/llava_example.py>`_.
Online OpenAI Vision API Compatible Inference
----------------------------------------------
You can serve vision language models with vLLM's HTTP server, which is compatible with the `OpenAI Vision API <https://platform.openai.com/docs/guides/vision>`_.
.. note::
Currently, vLLM supports only a **single** ``image_url`` input per ``messages``. Support for multi-image inputs will be
added in the future.
Below is an example of how to launch the same ``llava-hf/llava-1.5-7b-hf`` model with the vLLM API server.
.. important::
Since the OpenAI Vision API is based on the `Chat <https://platform.openai.com/docs/api-reference/chat>`_ API, a chat template
is **required** to launch the API server if the model's tokenizer does not come with one. In this example, we use the
HuggingFace Llava chat template, which you can find in the examples folder `here <https://github.com/vllm-project/vllm/blob/main/examples/template_llava.jinja>`_.
.. code-block:: bash
python -m vllm.entrypoints.openai.api_server \
--model llava-hf/llava-1.5-7b-hf \
--image-input-type pixel_values \
--image-token-id 32000 \
--image-input-shape 1,3,336,336 \
--image-feature-size 576 \
--chat-template template_llava.jinja
To consume the server, you can use the OpenAI client as shown in the example below:
.. code-block:: python
from openai import OpenAI
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
chat_response = client.chat.completions.create(
model="llava-hf/llava-1.5-7b-hf",
messages=[{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
},
},
],
}],
)
print("Chat response:", chat_response)
.. note::
By default, the timeout for fetching images through an HTTP URL is ``5`` seconds. You can override this by setting the environment variable:
.. code-block:: shell
export VLLM_IMAGE_FETCH_TIMEOUT=<timeout>
.. note::
There is no need to format the prompt with the image token ``<image>`` when serving VLMs with the API server, since the
prompt is processed automatically by the server.
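For illustration, here is a minimal sketch of the expansion the server performs for LLaVA-style models, mirroring
``VisionLanguageConfig.get_image_token_text`` and ``get_full_image_text_prompt``; the values ``<image>`` and ``576``
correspond to the launch arguments above.
.. code-block:: python
    def expand_vision_prompt(text_prompt: str,
                             image_token: str = "<image>",
                             image_feature_size: int = 576) -> str:
        # The image placeholder is repeated once per image feature, then the
        # user text is appended on a new line (LLaVA prompt format).
        image_prompt = image_token * image_feature_size
        return f"{image_prompt}\n{text_prompt}"

    print(expand_vision_prompt("What's in this image?")[-40:])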

View File

@@ -30,6 +30,8 @@ Please see the [OpenAI API Reference](https://platform.openai.com/docs/api-refer
- Chat: `tools`, and `tool_choice`.
- Completions: `suffix`.
vLLM also provides experimental support for OpenAI Vision API-compatible inference. See more details in [Using VLMs](../models/vlm.rst).
## Extra Parameters
vLLM supports a set of parameters that are not part of the OpenAI API.
In order to use them, you can pass them as extra parameters in the OpenAI client.

View File

@@ -0,0 +1,23 @@
{%- if messages[0]['role'] == 'system' -%}
{%- set system_message = messages[0]['content'] -%}
{%- set messages = messages[1:] -%}
{%- else -%}
{% set system_message = '' -%}
{%- endif -%}
{{ bos_token + system_message }}
{%- for message in messages -%}
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{%- endif -%}
{%- if message['role'] == 'user' -%}
{{ 'USER: ' + message['content'] + '\n' }}
{%- elif message['role'] == 'assistant' -%}
{{ 'ASSISTANT: ' + message['content'] + eos_token + '\n' }}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{ 'ASSISTANT:' }}
{% endif %}
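For reference, below is a minimal rendering sketch (assuming jinja2 is installed and the template above is saved as
examples/template_llava.jinja) showing what this chat template produces for a single user turn:
from jinja2 import Template

def raise_exception(message):
    # Helper normally injected by the HuggingFace tokenizer when applying chat
    # templates; defined here so the template can be rendered standalone.
    raise ValueError(message)

with open("examples/template_llava.jinja") as f:
    template = Template(f.read())

rendered = template.render(
    messages=[{"role": "user", "content": "What's in this image?"}],
    bos_token="<s>",
    eos_token="</s>",
    add_generation_prompt=True,
    raise_exception=raise_exception,
)
print(rendered)
# Roughly: <s>USER: What's in this image?
#          ASSISTANT: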

View File

@@ -0,0 +1,286 @@
from pathlib import Path
from typing import Dict
import openai
import pytest
import pytest_asyncio
import ray
from vllm.multimodal.utils import ImageFetchAiohttp, encode_image_base64
from ..utils import ServerRunner
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
LLAVA_CHAT_TEMPLATE = (Path(__file__).parent.parent.parent /
"examples/template_llava.jinja")
assert LLAVA_CHAT_TEMPLATE.exists()
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
"https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
"https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
]
pytestmark = pytest.mark.openai
@pytest.fixture(scope="module")
def server():
ray.init()
server_runner = ServerRunner.remote([
"--model",
MODEL_NAME,
"--dtype",
"bfloat16",
"--max-model-len",
"4096",
"--enforce-eager",
"--image-input-type",
"pixel_values",
"--image-token-id",
"32000",
"--image-input-shape",
"1,3,336,336",
"--image-feature-size",
"576",
"--chat-template",
str(LLAVA_CHAT_TEMPLATE),
])
ray.get(server_runner.ready.remote())
yield server_runner
ray.shutdown()
@pytest.fixture(scope="session")
def client():
client = openai.AsyncOpenAI(
base_url="http://localhost:8000/v1",
api_key="token-abc123",
)
yield client
@pytest_asyncio.fixture(scope="session")
async def base64_encoded_image() -> Dict[str, str]:
return {
image_url:
encode_image_base64(await ImageFetchAiohttp.fetch_image(image_url))
for image_url in TEST_IMAGE_URLS
}
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_single_chat_session_image(server, client: openai.AsyncOpenAI,
model_name: str, image_url: str):
messages = [{
"role":
"user",
"content": [
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
{
"type": "text",
"text": "What's in this image?"
},
],
}]
# test single completion
chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=5)
assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=596, total_tokens=606)
message = choice.message
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
messages.append({"role": "assistant", "content": message.content})
# test multi-turn dialogue
messages.append({"role": "user", "content": "express your result in json"})
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
)
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_single_chat_session_image_base64encoded(
server, client: openai.AsyncOpenAI, model_name: str, image_url: str,
base64_encoded_image: Dict[str, str]):
messages = [{
"role":
"user",
"content": [
{
"type": "image_url",
"image_url": {
"url":
f"data:image/jpeg;base64,{base64_encoded_image[image_url]}"
}
},
{
"type": "text",
"text": "What's in this image?"
},
],
}]
# test single completion
chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=5)
assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=596, total_tokens=606)
message = choice.message
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
messages.append({"role": "assistant", "content": message.content})
# test multi-turn dialogue
messages.append({"role": "user", "content": "express your result in json"})
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
)
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_chat_streaming_image(server, client: openai.AsyncOpenAI,
model_name: str, image_url: str):
messages = [{
"role":
"user",
"content": [
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
{
"type": "text",
"text": "What's in this image?"
},
],
}]
# test single completion
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
)
output = chat_completion.choices[0].message.content
stop_reason = chat_completion.choices[0].finish_reason
# test streaming
stream = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
stream=True,
)
chunks = []
finish_reason_count = 0
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.role:
assert delta.role == "assistant"
if delta.content:
chunks.append(delta.content)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
# finish reason should only return in last block
assert finish_reason_count == 1
assert chunk.choices[0].finish_reason == stop_reason
assert delta.content
assert "".join(chunks) == output
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_multi_image_input(server, client: openai.AsyncOpenAI,
model_name: str, image_url: str):
messages = [{
"role":
"user",
"content": [
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
{
"type": "text",
"text": "What's in this image?"
},
],
}]
with pytest.raises(openai.BadRequestError): # test multi-image input
await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
)
# the server should still work afterwards
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
completion = completion.choices[0].text
assert completion is not None and len(completion) >= 0
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -0,0 +1,75 @@
import base64
import mimetypes
from tempfile import NamedTemporaryFile
from typing import Dict, Tuple
import numpy as np
import pytest
import pytest_asyncio
from PIL import Image
from vllm.multimodal.utils import ImageFetchAiohttp
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
"https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
"https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
]
@pytest_asyncio.fixture(scope="session")
async def url_images() -> Dict[str, Image.Image]:
return {
image_url: await ImageFetchAiohttp.fetch_image(image_url)
for image_url in TEST_IMAGE_URLS
}
def get_supported_suffixes() -> Tuple[str, ...]:
# We should at least test the file types mentioned in GPT-4 with Vision
OPENAI_SUPPORTED_SUFFIXES = ('.png', '.jpeg', '.jpg', '.webp', '.gif')
# Additional file types that are supported by us
EXTRA_SUPPORTED_SUFFIXES = ('.bmp', '.tiff')
return OPENAI_SUPPORTED_SUFFIXES + EXTRA_SUPPORTED_SUFFIXES
def _image_equals(a: Image.Image, b: Image.Image) -> bool:
return (np.asarray(a) == np.asarray(b.convert(a.mode))).all()
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
image_url: str, suffix: str):
url_image = url_images[image_url]
try:
mime_type = Image.MIME[Image.registered_extensions()[suffix]]
except KeyError:
try:
mime_type = mimetypes.types_map[suffix]
except KeyError:
pytest.skip('No MIME type')
with NamedTemporaryFile(suffix=suffix) as f:
try:
url_image.save(f.name)
except Exception as e:
if e.args[0] == 'cannot write mode RGBA as JPEG':
pytest.skip('Conversion not supported')
raise
base64_image = base64.b64encode(f.read()).decode("utf-8")
data_url = f"data:{mime_type};base64,{base64_image}"
data_image = await ImageFetchAiohttp.fetch_image(data_url)
if _image_equals(url_image, Image.open(f)):
assert _image_equals(url_image, data_image)
else:
pass # Lossy format; only check that image can be opened

View File

@@ -5,7 +5,7 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple,
Union)
import torch
from transformers import PretrainedConfig
from transformers import PretrainedConfig, PreTrainedTokenizerBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -1119,6 +1119,16 @@ class VisionLanguageConfig:
f"Expecting to choose from "
f"{[x.name for x in cls.ImageInputType]}.") from e
#TODO(ywang96): make this a cached property once we refactor the
# VisionLanguageConfig class.
def get_image_token_text(
self, tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]:
"""Get the image token placeholder text to be inserted into the
text prompt and the string representation of the image token id.
"""
image_token_str = tokenizer.decode(self.image_token_id)
return image_token_str * self.image_feature_size, image_token_str
def as_cli_args_dict(self) -> Dict[str, Any]:
"""Flatten vision language config to pure args.

View File

@@ -1,15 +1,16 @@
import codecs
import time
from dataclasses import dataclass
from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List,
Optional)
from dataclasses import dataclass, field
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable,
List, Optional)
from typing import Sequence as GenericSequence
from typing import TypedDict, Union, cast, final
from fastapi import Request
from openai.types.chat import ChatCompletionContentPartTextParam
from openai.types.chat import (ChatCompletionContentPartImageParam,
ChatCompletionContentPartTextParam)
from vllm.config import ModelConfig
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
ChatCompletionContentPartParam, ChatCompletionLogProb,
@@ -21,9 +22,13 @@ from vllm.entrypoints.openai.protocol import (
FunctionCall, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.multimodal.image import ImagePixelData
from vllm.multimodal.utils import (async_get_and_parse_image,
get_full_image_text_prompt)
from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.utils import random_uuid
@@ -40,6 +45,8 @@ class ConversationMessage(TypedDict):
@dataclass(frozen=True)
class ChatMessageParseResult:
messages: List[ConversationMessage]
image_futures: List[Awaitable[ImagePixelData]] = field(
default_factory=list)
class OpenAIServingChat(OpenAIServing):
@@ -94,19 +101,76 @@ class OpenAIServingChat(OpenAIServing):
parts: Iterable[ChatCompletionContentPartParam],
) -> ChatMessageParseResult:
texts: List[str] = []
image_futures: List[Awaitable[ImagePixelData]] = []
for _, part in enumerate(parts):
vlm_config: Optional[VisionLanguageConfig] = getattr(
self.engine.engine, "vision_language_config", None)
model_config = getattr(self.engine.engine, "model_config", None)
for part in parts:
part_type = part["type"]
if part_type == "text":
text = cast(ChatCompletionContentPartTextParam, part)["text"]
texts.append(text)
elif part_type == "image_url":
if vlm_config is None:
raise ValueError(
"'image_url' input is not supported as the loaded "
"model is not multimodal.")
elif len(image_futures) == 0:
assert self.tokenizer is not None
image_url = cast(ChatCompletionContentPartImageParam,
part)["image_url"]
if image_url.get("detail", "auto") != "auto":
logger.warning(
"'image_url.detail' is currently not supported and "
"will be ignored.")
image_future = async_get_and_parse_image(image_url["url"])
image_futures.append(image_future)
else:
raise NotImplementedError(
"Multiple 'image_url' input is currently not supported."
)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
messages = [ConversationMessage(role=role, content="\n".join(texts))]
text_prompt = "\n".join(texts)
return ChatMessageParseResult(messages=messages)
if vlm_config is not None and len(image_futures):
(image_token_prompt,
image_token_str) = vlm_config.get_image_token_text(self.tokenizer)
# NOTE: If the image token string (e.g., <image>) is already present
# in the text prompt, we assume it follows the same format required
# by the engine.
if image_token_str in text_prompt:
logger.warning(
"Detected image token string in the text prompt. "
"Skipping prompt formatting.")
messages = [
ConversationMessage(role=role, content=text_prompt)
]
else:
full_prompt = get_full_image_text_prompt(
image_prompt=image_token_prompt,
text_prompt=text_prompt,
config=model_config)
messages = [
ConversationMessage(role=role, content=full_prompt)
]
else:
messages = [ConversationMessage(role=role, content=text_prompt)]
return ChatMessageParseResult(messages=messages,
image_futures=image_futures)
def _parse_chat_message_content(
self,
@@ -116,10 +180,10 @@ class OpenAIServingChat(OpenAIServing):
content = message.get("content")
if content is None:
return ChatMessageParseResult(messages=[])
return ChatMessageParseResult(messages=[], image_futures=[])
if isinstance(content, str):
messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages)
return ChatMessageParseResult(messages=messages, image_futures=[])
return self._parse_chat_message_content_parts(role, content)
@@ -144,11 +208,13 @@ class OpenAIServingChat(OpenAIServing):
try:
conversation: List[ConversationMessage] = []
image_futures: List[Awaitable[ImagePixelData]] = []
for msg in request.messages:
parsed_msg = self._parse_chat_message_content(msg)
chat_parsed_result = self._parse_chat_message_content(msg)
conversation.extend(parsed_msg.messages)
conversation.extend(chat_parsed_result.messages)
image_futures.extend(chat_parsed_result.image_futures)
prompt = self.tokenizer.apply_chat_template(
conversation=conversation,
@@ -159,6 +225,17 @@
logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e))
# Fetch image data
image_data: Optional[ImagePixelData] = None
try:
if len(image_futures):
# since we support only single image currently
assert len(image_futures) == 1
image_data = await image_futures[0]
except Exception as e:
logger.error("Error in loading image data: %s", e)
return self.create_error_response(str(e))
request_id = f"cmpl-{random_uuid()}"
try:
# Tokenize/detokenize depending on prompt format (string/token list)
@@ -183,11 +260,15 @@
except ValueError as e:
return self.create_error_response(str(e))
inputs: PromptInputs = {
"prompt": prompt_text,
"prompt_token_ids": prompt_ids,
}
if image_data is not None:
inputs["multi_modal_data"] = image_data
result_generator = self.engine.generate(
{
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
inputs,
sampling_params,
request_id,
lora_request,

View File

@@ -29,6 +29,7 @@ if TYPE_CHECKING:
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "spawn"
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_TARGET_DEVICE: str = "cuda"
MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None
@@ -216,6 +217,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Both spawn and fork work
"VLLM_WORKER_MULTIPROC_METHOD":
lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"),
# Timeout for fetching images when serving multimodal models
# Default is 5 seconds
"VLLM_IMAGE_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),
}
# end-env-vars-definition
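A minimal sketch of overriding the new timeout; the value 20 is illustrative and the fallback mirrors the lambda above.
import os

# Must be set before the API server (or any image fetch) starts.
os.environ["VLLM_IMAGE_FETCH_TIMEOUT"] = "20"

# Mirrors the registered lambda: falls back to 5 seconds when unset.
timeout_s = int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5"))
assert timeout_s == 20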

vllm/multimodal/utils.py
View File

@@ -0,0 +1,85 @@
import base64
from io import BytesIO
from typing import Optional, Union
import aiohttp
from PIL import Image
from vllm.config import ModelConfig
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
from vllm.multimodal.image import ImagePixelData
class ImageFetchAiohttp:
aiohttp_client: Optional[aiohttp.ClientSession] = None
@classmethod
def get_aiohttp_client(cls) -> aiohttp.ClientSession:
if cls.aiohttp_client is None:
timeout = aiohttp.ClientTimeout(total=VLLM_IMAGE_FETCH_TIMEOUT)
connector = aiohttp.TCPConnector()
cls.aiohttp_client = aiohttp.ClientSession(timeout=timeout,
connector=connector)
return cls.aiohttp_client
@classmethod
async def fetch_image(cls, image_url: str) -> Image.Image:
"""Load PIL image from a url or base64 encoded openai GPT4V format"""
if image_url.startswith('http'):
# Avoid circular import
from vllm import __version__ as VLLM_VERSION
client = cls.get_aiohttp_client()
headers = {"User-Agent": f"vLLM/{VLLM_VERSION}"}
async with client.get(url=image_url, headers=headers) as response:
response.raise_for_status()
image_raw = await response.read()
image = Image.open(BytesIO(image_raw))
# Only split once and assume the second part is the base64 encoded image
elif image_url.startswith('data:image'):
image = load_image_from_base64(image_url.split(',', 1)[1])
else:
raise ValueError("Invalid image url: A valid image url must start "
"with either 'data:image' or 'http'.")
return image
async def async_get_and_parse_image(image_url: str) -> ImagePixelData:
with await ImageFetchAiohttp.fetch_image(image_url) as image:
return ImagePixelData(image)
def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str:
"""encode image to base64 format."""
buffered = BytesIO()
if format == 'JPEG':
image = image.convert('RGB')
image.save(buffered, format)
return base64.b64encode(buffered.getvalue()).decode('utf-8')
def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
"""Load image from base64 format."""
return Image.open(BytesIO(base64.b64decode(image)))
# TODO(ywang96): move this to a model registry for preprocessing vision
# language prompts based on the model type.
def get_full_image_text_prompt(image_prompt: str, text_prompt: str,
config: ModelConfig) -> str:
"""Combine image and text prompts for vision language model depending on
the model architecture."""
if config.hf_config.model_type == "llava":
full_prompt = f"{image_prompt}\n{text_prompt}"
else:
raise ValueError(
f"Unsupported model type: {config.hf_config.model_type}")
return full_prompt
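For reference, a minimal usage sketch of the helpers above (assuming vLLM and Pillow are installed; the 8x8 red image is
only a placeholder):
import asyncio

from PIL import Image

from vllm.multimodal.utils import (ImageFetchAiohttp, encode_image_base64,
                                   load_image_from_base64)

image = Image.new("RGB", (8, 8), color="red")

# Round-trip through base64 (JPEG is lossy, so only the size is compared).
b64 = encode_image_base64(image)
restored = load_image_from_base64(b64)
assert restored.size == image.size

# The same payload can be fetched through the data URL path used by the server.
data_url = f"data:image/jpeg;base64,{b64}"
fetched = asyncio.run(ImageFetchAiohttp.fetch_image(data_url))
assert fetched.size == image.size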