[Frontend] support image embeds (#13955)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Parent: 60a98b2de5
Commit: b0746fae3d

@@ -462,4 +462,69 @@ export VLLM_AUDIO_FETCH_TIMEOUT=<timeout>

### Embedding Inputs

To input pre-computed embeddings for a given modality (i.e. image, video, or audio) directly to the language model, pass a tensor of the appropriate shape to the corresponding field of the multi-modal dictionary.
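
For offline inference, a minimal sketch might look like this (the prompt format and the embedding shape are illustrative assumptions for `llava-hf/llava-1.5-7b-hf`, not taken from this commit):

```python
import torch
from vllm import LLM

llm = LLM(model="llava-hf/llava-1.5-7b-hf")

# Pre-computed image embeddings; the expected shape depends on the model
# (for LLaVA-style models, roughly (num_image_tokens, hidden_size)).
image_embeds = torch.load(...)

outputs = llm.generate({
    "prompt": "USER: <image>\nWhat is the content of this image?\nASSISTANT:",
    "multi_modal_data": {"image": image_embeds},
})
```
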
#### Image Embedding Inputs

For image embeddings, you can pass the base64-encoded tensor to the `image_embeds` field.
The following example demonstrates how to pass image embeddings to the OpenAI-compatible server:

```python
import base64
import io

import torch
from openai import OpenAI

image_embedding = torch.load(...)
grid_thw = torch.load(...)  # Required by Qwen/Qwen2-VL-2B-Instruct
image_sizes = torch.load(...)  # Required by openbmb/MiniCPM-V-2_6

buffer = io.BytesIO()
torch.save(image_embedding, buffer)
buffer.seek(0)
binary_data = buffer.read()
base64_image_embedding = base64.b64encode(binary_data).decode('utf-8')

# Model-specific tensors are serialized the same way
buffer = io.BytesIO()
torch.save(grid_thw, buffer)
buffer.seek(0)
base64_image_grid_thw = base64.b64encode(buffer.read()).decode('utf-8')

buffer = io.BytesIO()
torch.save(image_sizes, buffer)
buffer.seek(0)
base64_image_sizes = base64.b64encode(buffer.read()).decode('utf-8')

# Assumes a locally running OpenAI-compatible vLLM server; adjust as needed
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

# Basic usage - this is equivalent to the LLaVA example for offline inference
model = "llava-hf/llava-1.5-7b-hf"
embeds = {
    "type": "image_embeds",
    "image_embeds": f"{base64_image_embedding}",
}

# Pass additional parameters (available to Qwen2-VL and MiniCPM-V)
model = "Qwen/Qwen2-VL-2B-Instruct"
embeds = {
    "type": "image_embeds",
    "image_embeds": {
        "image_embeds": f"{base64_image_embedding}",  # Required
        "image_grid_thw": f"{base64_image_grid_thw}",  # Required by Qwen/Qwen2-VL-2B-Instruct
    },
}

model = "openbmb/MiniCPM-V-2_6"
embeds = {
    "type": "image_embeds",
    "image_embeds": {
        "image_embeds": f"{base64_image_embedding}",  # Required
        "image_sizes": f"{base64_image_sizes}",  # Required by openbmb/MiniCPM-V-2_6
    },
}

chat_completion = client.chat.completions.create(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [
            {
                "type": "text",
                "text": "What's in this image?",
            },
            embeds,
        ]},
    ],
    model=model,
)
```

:::{note}
Only one message can contain `{"type": "image_embeds"}`.
If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. `image_grid_thw`, `image_sizes`.
:::

@@ -56,6 +56,17 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):

    """The type of the content part."""


class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
    image_embeds: Required[Union[str, dict[str, str]]]
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""

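For illustration, the two accepted shapes described in the docstring could be constructed like this (a sketch; `base64_image_embedding` and `base64_image_grid_thw` are the hypothetical strings from the documentation example above):

```python
# Single base64 string (as in the LLaVA example in the docs)
part_single: ChatCompletionContentPartImageEmbedsParam = {
    "type": "image_embeds",
    "image_embeds": base64_image_embedding,
}

# Dictionary of base64 strings (models that need extra tensors, e.g. Qwen2-VL)
part_dict: ChatCompletionContentPartImageEmbedsParam = {
    "type": "image_embeds",
    "image_embeds": {
        "image_embeds": base64_image_embedding,
        "image_grid_thw": base64_image_grid_thw,
    },
}
```
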
class VideoURL(TypedDict, total=False):
    url: Required[str]
    """

@@ -109,6 +120,7 @@ ChatCompletionContentPartParam: TypeAlias = Union[
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentSimpleImageParam,
    ChatCompletionContentPartImageEmbedsParam,
    CustomChatCompletionContentSimpleAudioParam,
    CustomChatCompletionContentSimpleVideoParam, str]

@@ -350,7 +362,7 @@ def resolve_chat_template_content_format(
    return detected_format


ModalityStr = Literal["image", "audio", "video", "image_embeds"]
_T = TypeVar("_T")

@@ -391,7 +403,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
        hf_config = self._model_config.hf_config
        model_type = hf_config.model_type

        if modality in ["image", "image_embeds"]:
            if model_type == "phi3_v":
                # Workaround since this token is not defined in the tokenizer
                return f"<|image_{current_count}|>"

@@ -470,10 +482,27 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):

    def all_mm_data(self) -> Optional[MultiModalDataDict]:
        if not self._items_by_modality:
            return None
        mm_inputs = {}
        items_by_modality = dict(self._items_by_modality)
        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed")

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}")
            mm_inputs["image"] = image_embeds_lst[0]
        elif "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        elif "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        elif "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return MultiModalContentParser(self)

@@ -482,13 +511,31 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):

    async def all_mm_data(self) -> Optional[MultiModalDataDict]:
        if not self._items_by_modality:
            return None
        mm_inputs = {}
        items_by_modality = {
            modality: await asyncio.gather(*items)
            for modality, items in self._items_by_modality.items()
        }

        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed")

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}")
            mm_inputs["image"] = image_embeds_lst[0]
        elif "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        elif "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        elif "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return AsyncMultiModalContentParser(self)

@@ -513,6 +560,11 @@ class BaseMultiModalContentParser(ABC):
    def parse_image(self, image_url: str) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(self,
                           image_embeds: Union[str, dict[str, str]]) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str) -> None:
        raise NotImplementedError

@@ -543,6 +595,21 @@ class MultiModalContentParser(BaseMultiModalContentParser):
        placeholder = self._tracker.add("image", image)
        self._add_placeholder(placeholder)

    def parse_image_embeds(self,
                           image_embeds: Union[str, dict[str, str]]) -> None:
        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            placeholder = self._tracker.add("image_embeds", embeds)

        if isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            placeholder = self._tracker.add("image_embeds", embedding)

        self._add_placeholder(placeholder)

    def parse_audio(self, audio_url: str) -> None:
        audio = self._connector.fetch_audio(audio_url)

@@ -579,6 +646,25 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
        placeholder = self._tracker.add("image", image_coro)
        self._add_placeholder(placeholder)

    def parse_image_embeds(self,
                           image_embeds: Union[str, dict[str, str]]) -> None:
        future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()

        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            future.set_result(embeds)

        if isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            future.set_result(embedding)

        placeholder = self._tracker.add("image_embeds", future)
        self._add_placeholder(placeholder)

    def parse_audio(self, audio_url: str) -> None:
        audio_coro = self._connector.fetch_audio_async(audio_url)

@@ -684,6 +770,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
# No need to validate using Pydantic again
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)

@@ -700,6 +787,8 @@ MM_PARSER_MAP: dict[
    lambda part: _TextParser(part).get("text", ""),
    "image_url":
    lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
    "image_embeds":
    lambda part: _ImageEmbedsParser(part).get("image_embeds", {}),
    "audio_url":
    lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
    "input_audio":

@@ -769,6 +858,7 @@ def _parse_chat_message_content_mm_part(


VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
                                       "image_embeds",
                                       "audio_url", "input_audio", "video_url")

@@ -843,7 +933,10 @@ def _parse_chat_message_content_part(
        str_content = cast(str, content)
        mm_parser.parse_image(str_content)
        return {'type': 'image'} if wrap_dicts else None

    if part_type == "image_embeds":
        content = cast(Union[str, dict[str, str]], content)
        mm_parser.parse_image_embeds(content)
        return {'type': 'image'} if wrap_dicts else None

    if part_type == "audio_url":
        str_content = cast(str, content)
        mm_parser.parse_audio(str_content)

@@ -134,3 +134,22 @@ class ImageMediaIO(MediaIO[Image.Image]):
        data = buffer.getvalue()

        return base64.b64encode(data).decode('utf-8')


class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):

    def __init__(self) -> None:
        super().__init__()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        buffer = BytesIO(data)
        return torch.load(buffer, weights_only=True)

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> torch.Tensor:
        return torch.load(filepath)

    def encode_base64(self, media: torch.Tensor) -> str:
        return base64.b64encode(media.numpy()).decode('utf-8')

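As a quick sanity check of the new class (an illustrative sketch, not part of the commit; the tensor shape is made up), a tensor serialized with `torch.save` and base64-encoded on the client round-trips through `load_base64`:

```python
import base64
import io

import torch

tensor = torch.randn(1, 576, 4096)  # hypothetical embedding shape
buf = io.BytesIO()
torch.save(tensor, buf)
payload = base64.b64encode(buf.getvalue()).decode('utf-8')

decoded = ImageEmbeddingMediaIO().load_base64("", payload)
assert torch.equal(decoded, tensor)
```
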
@@ -7,6 +7,7 @@ from urllib.parse import ParseResult, urlparse

import numpy as np
import numpy.typing as npt
import torch
from PIL import Image

import vllm.envs as envs

@@ -16,7 +17,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer

from .audio import AudioMediaIO
from .base import MediaIO
from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .inputs import PlaceholderRange
from .video import VideoMediaIO

@@ -245,6 +246,17 @@ class MediaConnector:
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Load an image embedding from a base64-encoded string.
        """
        image_embedding_io = ImageEmbeddingMediaIO()

        return image_embedding_io.load_base64("", data)


global_media_connector = MediaConnector()
"""The global :class:`MediaConnector` instance used by vLLM."""
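
For context, a minimal sketch of how this connector method would be used (assuming `base64_image_embedding` is a base64-encoded, `torch.save`'d tensor as in the documentation example above):

```python
embedding = global_media_connector.fetch_image_embedding(base64_image_embedding)
print(embedding.shape)  # matches the tensor that was serialized client-side
```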