[Frontend] Multi-Modality Support for Loading Local Image Files (#9915)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
Chauncey 2024-11-04 23:34:57 +08:00 committed by GitHub
parent ccb5376a9a
commit ac6b8f19b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 132 additions and 14 deletions

View File

@ -1,11 +1,12 @@
import base64
import mimetypes
from tempfile import NamedTemporaryFile
import os
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Dict, Tuple
import numpy as np
import pytest
from PIL import Image
from PIL import Image, ImageChops
from transformers import AutoConfig, AutoTokenizer
from vllm.multimodal.utils import (async_fetch_image, fetch_image,
@ -84,6 +85,40 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
assert _image_equals(data_image_sync, data_image_async)
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_local_files(image_url: str):
    """Check loading of ``file://`` image URLs via the sync and async fetchers.

    The remote image is fetched and saved into a temporary directory, then
    read back through ``file://`` URLs with that directory passed as
    ``allowed_local_media_path``. URLs containing a ``..`` segment must raise
    ``ValueError`` whether or not an allowed path is supplied.
    """
    file_name = os.path.basename(image_url)

    with TemporaryDirectory() as temp_dir:
        local_url = f"file://{temp_dir}/{file_name}"
        escaping_url = f"file://{temp_dir}/../{file_name}"

        # Materialize the remote image on disk so it can be read back locally.
        origin_image = fetch_image(image_url)
        origin_image.save(os.path.join(temp_dir, file_name),
                          quality=100,
                          icc_profile=origin_image.info.get('icc_profile'))

        image_async = await async_fetch_image(
            local_url, allowed_local_media_path=temp_dir)
        image_sync = fetch_image(local_url,
                                 allowed_local_media_path=temp_dir)
        # Check that the images are equal
        assert not ImageChops.difference(image_sync, image_async).getbbox()

        # Async fetcher: a path with a '..' segment is rejected, both with
        # and without an allowed directory.
        with pytest.raises(ValueError):
            await async_fetch_image(escaping_url,
                                    allowed_local_media_path=temp_dir)
        with pytest.raises(ValueError):
            await async_fetch_image(escaping_url)

        # Sync fetcher: same expectations.
        with pytest.raises(ValueError):
            fetch_image(escaping_url, allowed_local_media_path=temp_dir)
        with pytest.raises(ValueError):
            fetch_image(escaping_url)
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
def test_repeat_and_pad_placeholder_tokens(model):
config = AutoConfig.from_pretrained(model)

View File

@ -55,6 +55,10 @@ class ModelConfig:
"mistral" will always use the tokenizer from `mistral_common`.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
allowed_local_media_path: Allow API requests to read local images or
videos from the specified directory on the server's file system.
This is a security risk and should only be enabled in trusted
environments.
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
@ -134,6 +138,7 @@ class ModelConfig:
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
allowed_local_media_path: str = "",
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
@ -164,6 +169,7 @@ class ModelConfig:
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
self.allowed_local_media_path = allowed_local_media_path
self.seed = seed
self.revision = revision
self.code_revision = code_revision
@ -1319,6 +1325,8 @@ class SpeculativeConfig:
tokenizer=target_model_config.tokenizer,
tokenizer_mode=target_model_config.tokenizer_mode,
trust_remote_code=target_model_config.trust_remote_code,
allowed_local_media_path=target_model_config.
allowed_local_media_path,
dtype=target_model_config.dtype,
seed=target_model_config.seed,
revision=draft_revision,

View File

@ -92,6 +92,7 @@ class EngineArgs:
tokenizer_mode: str = 'auto'
chat_template_text_format: str = 'string'
trust_remote_code: bool = False
allowed_local_media_path: str = ""
download_dir: Optional[str] = None
load_format: str = 'auto'
config_format: ConfigFormat = ConfigFormat.AUTO
@ -269,6 +270,13 @@ class EngineArgs:
parser.add_argument('--trust-remote-code',
action='store_true',
help='Trust remote code from huggingface.')
# NOTE: adjacent string literals are joined with no separator, so each
# fragment of the help text must carry its own trailing space.
parser.add_argument(
    '--allowed-local-media-path',
    type=str,
    help="Allowing API requests to read local images or videos "
    "from directories specified by the server file system. "
    "This is a security risk. "
    "Should only be enabled in trusted environments.")
parser.add_argument('--download-dir',
type=nullable_str,
default=EngineArgs.download_dir,
@ -920,6 +928,7 @@ class EngineArgs:
tokenizer_mode=self.tokenizer_mode,
chat_template_text_format=self.chat_template_text_format,
trust_remote_code=self.trust_remote_code,
allowed_local_media_path=self.allowed_local_media_path,
dtype=self.dtype,
seed=self.seed,
revision=self.revision,

View File

@ -307,7 +307,9 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self._tracker = tracker
def parse_image(self, image_url: str) -> None:
image = get_and_parse_image(image_url)
image = get_and_parse_image(image_url,
allowed_local_media_path=self._tracker.
_model_config.allowed_local_media_path)
placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)
@ -327,7 +329,10 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
self._tracker = tracker
def parse_image(self, image_url: str) -> None:
image_coro = async_get_and_parse_image(image_url)
image_coro = async_get_and_parse_image(
image_url,
allowed_local_media_path=self._tracker._model_config.
allowed_local_media_path)
placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)

View File

@ -58,6 +58,10 @@ class LLM:
from the input.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
allowed_local_media_path: Allow API requests to read local images
or videos from the specified directory on the server's file system.
This is a security risk and should only be enabled in trusted
environments.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
@ -139,6 +143,7 @@ class LLM:
tokenizer_mode: str = "auto",
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
allowed_local_media_path: str = "",
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
@ -179,6 +184,7 @@ class LLM:
tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code,
allowed_local_media_path=allowed_local_media_path,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,

View File

@ -1,4 +1,5 @@
import base64
import os
from functools import lru_cache
from io import BytesIO
from typing import Any, List, Optional, Tuple, TypeVar, Union
@ -18,19 +19,60 @@ logger = init_logger(__name__)
cached_get_tokenizer = lru_cache(get_tokenizer)
def _load_image_from_bytes(b: bytes):
def _load_image_from_bytes(b: bytes) -> Image.Image:
    """Open the image contained in ``b`` and return it after calling ``load()``."""
    buffer = BytesIO(b)
    decoded = Image.open(buffer)
    # Force the image data to be read now rather than on first access.
    decoded.load()
    return decoded
def _load_image_from_data_url(image_url: str):
def _is_subpath(image_path: str, allowed_local_media_path: str) -> bool:
# Get the common path
common_path = os.path.commonpath([
os.path.abspath(image_path),
os.path.abspath(allowed_local_media_path)
])
# Check if the common path is the same as allowed_local_media_path
return common_path == os.path.abspath(allowed_local_media_path)
def _load_image_from_file(image_url: str,
                          allowed_local_media_path: str) -> Image.Image:
    """Load an image referenced by a ``file://`` URL from the local disk.

    Args:
        image_url: URL of the form ``file://<path>``. Everything after the
            first ``file://`` is used as the file system path.
        allowed_local_media_path: Directory that the file must reside in.
            An empty value disables local file loading entirely.

    Returns:
        The opened image, after ``load()`` has been called on it.

    Raises:
        ValueError: If local loading is disabled, if
            ``allowed_local_media_path`` does not exist or is not a
            directory, or if the requested file is not a subpath of it.
    """
    if not allowed_local_media_path:
        raise ValueError("Invalid 'image_url': Cannot load local files without "
                         "'--allowed-local-media-path'.")

    # Validate the configured directory before looking at the request itself.
    if not os.path.exists(allowed_local_media_path):
        raise ValueError(
            "Invalid '--allowed-local-media-path': "
            f"The path {allowed_local_media_path} does not exist.")
    if not os.path.isdir(allowed_local_media_path):
        raise ValueError(
            "Invalid '--allowed-local-media-path': "
            f"The path {allowed_local_media_path} must be a directory.")

    # Only split once and assume the second part is the image path
    _, image_path = image_url.split("file://", 1)
    if not _is_subpath(image_path, allowed_local_media_path):
        raise ValueError(
            f"Invalid 'image_url': The file path {image_path} must"
            " be a subpath of '--allowed-local-media-path'"
            f" '{allowed_local_media_path}'.")

    image = Image.open(image_path)
    image.load()
    return image
def _load_image_from_data_url(image_url: str) -> Image.Image:
    """Decode the base64 payload of a data URL into an image.

    The URL is split at its first comma; the text after it is handed to
    ``load_image_from_base64``.
    """
    # Split a single time so commas inside the payload are left untouched.
    _header, payload = image_url.split(",", 1)
    return load_image_from_base64(payload)
def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
def fetch_image(image_url: str,
*,
image_mode: str = "RGB",
allowed_local_media_path: str = "") -> Image.Image:
"""
Load a PIL image from an HTTP URL, a base64 data URL, or a local file URL.
@ -43,16 +85,19 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url)
elif image_url.startswith('file://'):
image = _load_image_from_file(image_url, allowed_local_media_path)
else:
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")
"with either 'data:image', 'file://' or 'http'.")
return image.convert(image_mode)
async def async_fetch_image(image_url: str,
*,
image_mode: str = "RGB") -> Image.Image:
image_mode: str = "RGB",
allowed_local_media_path: str = "") -> Image.Image:
"""
Asynchronously load a PIL image from an HTTP URL, a base64 data URL, or a local file URL.
@ -65,9 +110,11 @@ async def async_fetch_image(image_url: str,
elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url)
elif image_url.startswith('file://'):
image = _load_image_from_file(image_url, allowed_local_media_path)
else:
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")
"with either 'data:image', 'file://' or 'http'.")
return image.convert(image_mode)
@ -126,8 +173,12 @@ def get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
return {"audio": (audio, sr)}
def get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = fetch_image(image_url)
def get_and_parse_image(
        image_url: str,
        *,
        allowed_local_media_path: str = "") -> MultiModalDataDict:
    """Fetch the image at ``image_url`` and wrap it under the ``"image"`` key.

    ``allowed_local_media_path`` is forwarded unchanged to ``fetch_image``.
    """
    return {
        "image":
        fetch_image(image_url,
                    allowed_local_media_path=allowed_local_media_path)
    }
@ -136,8 +187,12 @@ async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
return {"audio": (audio, sr)}
async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = await async_fetch_image(image_url)
async def async_get_and_parse_image(
        image_url: str,
        *,
        allowed_local_media_path: str = "") -> MultiModalDataDict:
    """Await the image at ``image_url`` and wrap it under the ``"image"`` key.

    ``allowed_local_media_path`` is forwarded unchanged to
    ``async_fetch_image``.
    """
    return {
        "image":
        await async_fetch_image(
            image_url, allowed_local_media_path=allowed_local_media_path)
    }