[Frontend] Multi-Modality Support for Loading Local Image Files (#9915)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
parent
ccb5376a9a
commit
ac6b8f19b9
@ -1,11 +1,12 @@
|
||||
import base64
|
||||
import mimetypes
|
||||
from tempfile import NamedTemporaryFile
|
||||
import os
|
||||
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageChops
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
|
||||
from vllm.multimodal.utils import (async_fetch_image, fetch_image,
|
||||
@ -84,6 +85,40 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
|
||||
assert _image_equals(data_image_sync, data_image_async)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_local_files(image_url: str):
    """Round-trip a downloaded image through ``file://`` loading.

    The image is fetched once, saved into a temporary directory and then
    re-read through both the sync and async loaders; the two results must be
    pixel-identical. Paths that escape the allowed directory, or requests made
    without an allowed directory, must raise ``ValueError``.
    """
    file_name = os.path.basename(image_url)

    with TemporaryDirectory() as temp_dir:
        # Persist a local copy of the reference image.
        origin_image = fetch_image(image_url)
        saved_path = os.path.join(temp_dir, file_name)
        origin_image.save(saved_path,
                          quality=100,
                          icc_profile=origin_image.info.get('icc_profile'))

        inside_url = f"file://{temp_dir}/{file_name}"
        escaping_url = f"file://{temp_dir}/../{file_name}"

        image_async = await async_fetch_image(
            inside_url, allowed_local_media_path=temp_dir)
        image_sync = fetch_image(inside_url,
                                 allowed_local_media_path=temp_dir)

        # An empty bounding box means the two images have no differing pixel.
        assert not ImageChops.difference(image_sync, image_async).getbbox()

        # Async loader: traversal outside the allowed dir, then no allowed dir.
        with pytest.raises(ValueError):
            await async_fetch_image(escaping_url,
                                    allowed_local_media_path=temp_dir)
        with pytest.raises(ValueError):
            await async_fetch_image(escaping_url)

        # Sync loader: same two rejections.
        with pytest.raises(ValueError):
            fetch_image(escaping_url, allowed_local_media_path=temp_dir)
        with pytest.raises(ValueError):
            fetch_image(escaping_url)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
|
||||
def test_repeat_and_pad_placeholder_tokens(model):
|
||||
config = AutoConfig.from_pretrained(model)
|
||||
|
@ -55,6 +55,10 @@ class ModelConfig:
|
||||
"mistral" will always use the tokenizer from `mistral_common`.
|
||||
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
|
||||
downloading the model and tokenizer.
|
||||
allowed_local_media_path: Allow API requests to read local images or
|
||||
videos from directories specified by the server file system.
|
||||
This is a security risk. Should only be enabled in trusted
|
||||
environments.
|
||||
dtype: Data type for model weights and activations. The "auto" option
|
||||
will use FP16 precision for FP32 and FP16 models, and BF16 precision
|
||||
for BF16 models.
|
||||
@ -134,6 +138,7 @@ class ModelConfig:
|
||||
trust_remote_code: bool,
|
||||
dtype: Union[str, torch.dtype],
|
||||
seed: int,
|
||||
allowed_local_media_path: str = "",
|
||||
revision: Optional[str] = None,
|
||||
code_revision: Optional[str] = None,
|
||||
rope_scaling: Optional[dict] = None,
|
||||
@ -164,6 +169,7 @@ class ModelConfig:
|
||||
self.tokenizer = tokenizer
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
self.trust_remote_code = trust_remote_code
|
||||
self.allowed_local_media_path = allowed_local_media_path
|
||||
self.seed = seed
|
||||
self.revision = revision
|
||||
self.code_revision = code_revision
|
||||
@ -1319,6 +1325,8 @@ class SpeculativeConfig:
|
||||
tokenizer=target_model_config.tokenizer,
|
||||
tokenizer_mode=target_model_config.tokenizer_mode,
|
||||
trust_remote_code=target_model_config.trust_remote_code,
|
||||
allowed_local_media_path=target_model_config.
|
||||
allowed_local_media_path,
|
||||
dtype=target_model_config.dtype,
|
||||
seed=target_model_config.seed,
|
||||
revision=draft_revision,
|
||||
|
@ -92,6 +92,7 @@ class EngineArgs:
|
||||
tokenizer_mode: str = 'auto'
|
||||
chat_template_text_format: str = 'string'
|
||||
trust_remote_code: bool = False
|
||||
allowed_local_media_path: str = ""
|
||||
download_dir: Optional[str] = None
|
||||
load_format: str = 'auto'
|
||||
config_format: ConfigFormat = ConfigFormat.AUTO
|
||||
@ -269,6 +270,13 @@ class EngineArgs:
|
||||
parser.add_argument('--trust-remote-code',
|
||||
action='store_true',
|
||||
help='Trust remote code from huggingface.')
|
||||
# Opt-in switch for serving local media files. The default mirrors the
# EngineArgs field so an omitted flag yields "" (feature disabled) instead
# of None. Each help fragment ends with a space because adjacent string
# literals are concatenated verbatim.
parser.add_argument(
    '--allowed-local-media-path',
    type=str,
    default=EngineArgs.allowed_local_media_path,
    help="Allowing API requests to read local images or videos "
    "from directories specified by the server file system. "
    "This is a security risk. "
    "Should only be enabled in trusted environments.")
|
||||
parser.add_argument('--download-dir',
|
||||
type=nullable_str,
|
||||
default=EngineArgs.download_dir,
|
||||
@ -920,6 +928,7 @@ class EngineArgs:
|
||||
tokenizer_mode=self.tokenizer_mode,
|
||||
chat_template_text_format=self.chat_template_text_format,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
allowed_local_media_path=self.allowed_local_media_path,
|
||||
dtype=self.dtype,
|
||||
seed=self.seed,
|
||||
revision=self.revision,
|
||||
|
@ -307,7 +307,9 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
self._tracker = tracker
|
||||
|
||||
def parse_image(self, image_url: str) -> None:
    """Load the image at ``image_url`` and register its placeholder.

    The allowed local media directory is read from the tracker's model
    config and forwarded to the loader.
    """
    media_root = self._tracker._model_config.allowed_local_media_path
    image = get_and_parse_image(image_url,
                                allowed_local_media_path=media_root)

    self._add_placeholder(self._tracker.add("image", image))
|
||||
@ -327,7 +329,10 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
self._tracker = tracker
|
||||
|
||||
def parse_image(self, image_url: str) -> None:
    """Schedule loading of ``image_url`` and register its placeholder.

    The coroutine returned by the async loader is handed to the tracker
    without being awaited here. The allowed local media directory is read
    from the tracker's model config.
    """
    media_root = self._tracker._model_config.allowed_local_media_path
    image_coro = async_get_and_parse_image(
        image_url, allowed_local_media_path=media_root)

    self._add_placeholder(self._tracker.add("image", image_coro))
|
||||
|
@ -58,6 +58,10 @@ class LLM:
|
||||
from the input.
|
||||
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
|
||||
downloading the model and tokenizer.
|
||||
allowed_local_media_path: Allow API requests to read local images
|
||||
or videos from directories specified by the server file system.
|
||||
This is a security risk. Should only be enabled in trusted
|
||||
environments.
|
||||
tensor_parallel_size: The number of GPUs to use for distributed
|
||||
execution with tensor parallelism.
|
||||
dtype: The data type for the model weights and activations. Currently,
|
||||
@ -139,6 +143,7 @@ class LLM:
|
||||
tokenizer_mode: str = "auto",
|
||||
skip_tokenizer_init: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
allowed_local_media_path: str = "",
|
||||
tensor_parallel_size: int = 1,
|
||||
dtype: str = "auto",
|
||||
quantization: Optional[str] = None,
|
||||
@ -179,6 +184,7 @@ class LLM:
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
skip_tokenizer_init=skip_tokenizer_init,
|
||||
trust_remote_code=trust_remote_code,
|
||||
allowed_local_media_path=allowed_local_media_path,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
dtype=dtype,
|
||||
quantization=quantization,
|
||||
|
@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from io import BytesIO
|
||||
from typing import Any, List, Optional, Tuple, TypeVar, Union
|
||||
@ -18,19 +19,60 @@ logger = init_logger(__name__)
|
||||
cached_get_tokenizer = lru_cache(get_tokenizer)
|
||||
|
||||
|
||||
def _load_image_from_bytes(b: bytes):
|
||||
def _load_image_from_bytes(b: bytes) -> Image.Image:
    """Decode raw encoded image bytes into a fully loaded PIL image."""
    buffer = BytesIO(b)
    decoded = Image.open(buffer)
    # Force the pixel data to be read now rather than lazily.
    decoded.load()
    return decoded
|
||||
|
||||
|
||||
def _load_image_from_data_url(image_url: str):
|
||||
def _is_subpath(image_path: str, allowed_local_media_path: str) -> bool:
|
||||
# Get the common path
|
||||
common_path = os.path.commonpath([
|
||||
os.path.abspath(image_path),
|
||||
os.path.abspath(allowed_local_media_path)
|
||||
])
|
||||
# Check if the common path is the same as allowed_local_media_path
|
||||
return common_path == os.path.abspath(allowed_local_media_path)
|
||||
|
||||
|
||||
def _load_image_from_file(image_url: str,
                          allowed_local_media_path: str) -> Image.Image:
    """Load an image referenced by a ``file://`` URL from local disk.

    Local file access is refused unless an allowed media directory is
    configured, and the requested file must be located inside it.

    Args:
        image_url: URL of the form ``file://<path>``.
        allowed_local_media_path: Directory that local reads are confined
            to. An empty string disables local file loading.

    Returns:
        The fully loaded PIL image.

    Raises:
        ValueError: If no allowed directory is configured, if it does not
            exist or is not a directory, or if the requested path is not
            inside it.
    """
    if not allowed_local_media_path:
        raise ValueError("Invalid 'image_url': Cannot load local files without"
                         " '--allowed-local-media-path'.")
    if not os.path.exists(allowed_local_media_path):
        raise ValueError(
            "Invalid '--allowed-local-media-path': "
            f"The path {allowed_local_media_path} does not exist.")
    if not os.path.isdir(allowed_local_media_path):
        raise ValueError(
            "Invalid '--allowed-local-media-path': "
            f"The path {allowed_local_media_path} must be a directory.")

    # Only split once and assume the second part is the image path
    _, image_path = image_url.split("file://", 1)

    # The path comes from an API request. Resolve symlinks and ".." on both
    # sides before the containment check so that a link inside the allowed
    # directory cannot be used to read files elsewhere.
    resolved_image_path = os.path.realpath(image_path)
    resolved_allowed_path = os.path.realpath(allowed_local_media_path)
    if not _is_subpath(resolved_image_path, resolved_allowed_path):
        raise ValueError(
            f"Invalid 'image_url': The file path {image_path} must"
            " be a subpath of '--allowed-local-media-path'"
            f" '{allowed_local_media_path}'.")

    # Open the path that was actually validated.
    image = Image.open(resolved_image_path)
    image.load()
    return image
|
||||
|
||||
|
||||
def _load_image_from_data_url(image_url: str) -> Image.Image:
    """Decode the base64 payload of a ``data:`` URL into a PIL image."""
    # Everything after the first comma is taken to be the base64 payload;
    # unpacking fails with ValueError when the URL contains no comma.
    header_and_payload = image_url.split(",", 1)
    _, encoded_payload = header_and_payload
    return load_image_from_base64(encoded_payload)
|
||||
|
||||
|
||||
def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
|
||||
def fetch_image(image_url: str,
|
||||
*,
|
||||
image_mode: str = "RGB",
|
||||
allowed_local_media_path: str = "") -> Image.Image:
|
||||
"""
|
||||
Load a PIL image from an HTTP URL, a base64 data URL, or a local file URL.
|
||||
|
||||
@ -43,16 +85,19 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
|
||||
|
||||
elif image_url.startswith('data:image'):
|
||||
image = _load_image_from_data_url(image_url)
|
||||
elif image_url.startswith('file://'):
|
||||
image = _load_image_from_file(image_url, allowed_local_media_path)
|
||||
else:
|
||||
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
|
||||
"with either 'data:image' or 'http'.")
|
||||
"with either 'data:image', 'file://' or 'http'.")
|
||||
|
||||
return image.convert(image_mode)
|
||||
|
||||
|
||||
async def async_fetch_image(image_url: str,
|
||||
*,
|
||||
image_mode: str = "RGB") -> Image.Image:
|
||||
image_mode: str = "RGB",
|
||||
allowed_local_media_path: str = "") -> Image.Image:
|
||||
"""
|
||||
Asynchronously load a PIL image from an HTTP URL, a base64 data URL, or a local file URL.
|
||||
|
||||
@ -65,9 +110,11 @@ async def async_fetch_image(image_url: str,
|
||||
|
||||
elif image_url.startswith('data:image'):
|
||||
image = _load_image_from_data_url(image_url)
|
||||
elif image_url.startswith('file://'):
|
||||
image = _load_image_from_file(image_url, allowed_local_media_path)
|
||||
else:
|
||||
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
|
||||
"with either 'data:image' or 'http'.")
|
||||
"with either 'data:image', 'file://' or 'http'.")
|
||||
|
||||
return image.convert(image_mode)
|
||||
|
||||
@ -126,8 +173,12 @@ def get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
|
||||
return {"audio": (audio, sr)}
|
||||
|
||||
|
||||
def get_and_parse_image(image_url: str) -> MultiModalDataDict:
|
||||
image = fetch_image(image_url)
|
||||
def get_and_parse_image(
        image_url: str,
        *,
        allowed_local_media_path: str = "") -> MultiModalDataDict:
    """Fetch ``image_url`` and wrap the image as multi-modal data.

    ``allowed_local_media_path`` is forwarded unchanged to ``fetch_image``.
    The result is a dict with the single key ``"image"``.
    """
    loaded = fetch_image(image_url,
                         allowed_local_media_path=allowed_local_media_path)
    parsed: MultiModalDataDict = {"image": loaded}
    return parsed
|
||||
|
||||
|
||||
@ -136,8 +187,12 @@ async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
|
||||
return {"audio": (audio, sr)}
|
||||
|
||||
|
||||
async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
|
||||
image = await async_fetch_image(image_url)
|
||||
async def async_get_and_parse_image(
        image_url: str,
        *,
        allowed_local_media_path: str = "") -> MultiModalDataDict:
    """Asynchronously fetch ``image_url`` and wrap it as multi-modal data.

    ``allowed_local_media_path`` is forwarded unchanged to
    ``async_fetch_image``. The result is a dict with the single key
    ``"image"``.
    """
    loaded = await async_fetch_image(
        image_url, allowed_local_media_path=allowed_local_media_path)
    parsed: MultiModalDataDict = {"image": loaded}
    return parsed
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user