[Frontend] support image embeds (#13955)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
Chauncey 2025-03-10 20:36:03 +08:00 committed by GitHub
parent 60a98b2de5
commit b0746fae3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 200 additions and 11 deletions

View File

@ -462,4 +462,69 @@ export VLLM_AUDIO_FETCH_TIMEOUT=<timeout>
### Embedding Inputs
TBD
To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model,
pass a tensor of shape to the corresponding field of the multi-modal dictionary.
#### Image Embedding Inputs
For image embeddings, you can pass the base64-encoded tensor to the `image_embeds` field.
The following example demonstrates how to pass image embeddings to the OpenAI server:
```python
image_embedding = torch.load(...)
grid_thw = torch.load(...) # Required by Qwen/Qwen2-VL-2B-Instruct
buffer = io.BytesIO()
torch.save(image_embedding, buffer)
buffer.seek(0)
binary_data = buffer.read()
base64_image_embedding = base64.b64encode(binary_data).decode('utf-8')
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)
# Basic usage - this is equivalent to the LLaVA example for offline inference
model = "llava-hf/llava-1.5-7b-hf"
embeds = {
"type": "image_embeds",
"image_embeds": f"{base64_image_embedding}"
}
# Pass additional parameters (available to Qwen2-VL and MiniCPM-V)
model = "Qwen/Qwen2-VL-2B-Instruct"
embeds = {
"type": "image_embeds",
"image_embeds": {
"image_embeds": f"{base64_image_embedding}" , # Required
"image_grid_thw": f"{base64_image_grid_thw}" # Required by Qwen/Qwen2-VL-2B-Instruct
},
}
model = "openbmb/MiniCPM-V-2_6"
embeds = {
"type": "image_embeds",
"image_embeds": {
"image_embeds": f"{base64_image_embedding}" , # Required
"image_sizes": f"{base64_image_sizes}" # Required by openbmb/MiniCPM-V-2_6
},
}
chat_completion = client.chat.completions.create(
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": [
{
"type": "text",
"text": "What's in this image?",
},
embeds,
],
},
],
model=model,
)
```
:::{note}
Only one message can contain `{"type": "image_embeds"}`.
If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. `image_grid_thw`, `image_sizes`, etc.
:::

View File

@ -56,6 +56,17 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
"""The type of the content part."""
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
image_embeds: Required[Union[str, dict[str, str]]]
"""
The image embeddings. It can be either:
- A single base64 string.
- A dictionary where each value is a base64 string.
"""
type: Required[Literal["image_embeds"]]
"""The type of the content part."""
class VideoURL(TypedDict, total=False):
url: Required[str]
"""
@ -109,6 +120,7 @@ ChatCompletionContentPartParam: TypeAlias = Union[
ChatCompletionContentPartInputAudioParam,
ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentSimpleImageParam,
ChatCompletionContentPartImageEmbedsParam,
CustomChatCompletionContentSimpleAudioParam,
CustomChatCompletionContentSimpleVideoParam, str]
@ -350,7 +362,7 @@ def resolve_chat_template_content_format(
return detected_format
ModalityStr = Literal["image", "audio", "video"]
ModalityStr = Literal["image", "audio", "video", "image_embeds"]
_T = TypeVar("_T")
@ -391,7 +403,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
hf_config = self._model_config.hf_config
model_type = hf_config.model_type
if modality == "image":
if modality in ["image", "image_embeds"]:
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return f"<|image_{current_count}|>"
@ -470,10 +482,27 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
def all_mm_data(self) -> Optional[MultiModalDataDict]:
if self._items_by_modality:
return dict(self._items_by_modality)
if not self._items_by_modality:
return None
mm_inputs = {}
items_by_modality = dict(self._items_by_modality)
if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError(\
"Mixing raw image and embedding inputs is not allowed")
return None
if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"]
if len(image_embeds_lst) > 1:
raise ValueError(\
"Only one message can have {'type': 'image_embeds'}")
mm_inputs["image"] = image_embeds_lst[0]
elif "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images
elif "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
elif "video" in items_by_modality:
mm_inputs["video"] = items_by_modality["video"] # A list of videos
return mm_inputs
def create_parser(self) -> "BaseMultiModalContentParser":
return MultiModalContentParser(self)
@ -482,13 +511,31 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
async def all_mm_data(self) -> Optional[MultiModalDataDict]:
if self._items_by_modality:
return {
if not self._items_by_modality:
return None
mm_inputs = {}
items_by_modality = {
modality: await asyncio.gather(*items)
for modality, items in self._items_by_modality.items()
}
return None
if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError(
"Mixing raw image and embedding inputs is not allowed")
if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"]
if len(image_embeds_lst) > 1:
raise ValueError(
"Only one message can have {'type': 'image_embeds'}")
mm_inputs["image"] = image_embeds_lst[0]
elif "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images
elif "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
elif "video" in items_by_modality:
mm_inputs["video"] = items_by_modality["video"] # A list of videos
return mm_inputs
def create_parser(self) -> "BaseMultiModalContentParser":
return AsyncMultiModalContentParser(self)
@ -513,6 +560,11 @@ class BaseMultiModalContentParser(ABC):
def parse_image(self, image_url: str) -> None:
raise NotImplementedError
@abstractmethod
def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
raise NotImplementedError
@abstractmethod
def parse_audio(self, audio_url: str) -> None:
raise NotImplementedError
@ -543,6 +595,21 @@ class MultiModalContentParser(BaseMultiModalContentParser):
placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)
def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
if isinstance(image_embeds, dict):
embeds = {
k: self._connector.fetch_image_embedding(v)
for k, v in image_embeds.items()
}
placeholder = self._tracker.add("image_embeds", embeds)
if isinstance(image_embeds, str):
embedding = self._connector.fetch_image_embedding(image_embeds)
placeholder = self._tracker.add("image_embeds", embedding)
self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None:
audio = self._connector.fetch_audio(audio_url)
@ -579,6 +646,25 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)
def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()
if isinstance(image_embeds, dict):
embeds = {
k: self._connector.fetch_image_embedding(v)
for k, v in image_embeds.items()
}
future.set_result(embeds)
if isinstance(image_embeds, str):
embedding = self._connector.\
fetch_image_embedding(image_embeds)
future.set_result(embedding)
placeholder = self._tracker.add("image_embeds", future)
self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None:
audio_coro = self._connector.fetch_audio_async(audio_url)
@ -684,6 +770,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
# No need to validate using Pydantic again
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
@ -700,6 +787,8 @@ MM_PARSER_MAP: dict[
lambda part: _TextParser(part).get("text", ""),
"image_url":
lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
"image_embeds":
lambda part: _ImageEmbedsParser(part).get("image_embeds", {}),
"audio_url":
lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
"input_audio":
@ -769,6 +858,7 @@ def _parse_chat_message_content_mm_part(
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
"image_embeds",
"audio_url", "input_audio", "video_url")
@ -843,7 +933,10 @@ def _parse_chat_message_content_part(
str_content = cast(str, content)
mm_parser.parse_image(str_content)
return {'type': 'image'} if wrap_dicts else None
if part_type == "image_embeds":
content = cast(Union[str, dict[str, str]], content)
mm_parser.parse_image_embeds(content)
return {'type': 'image'} if wrap_dicts else None
if part_type == "audio_url":
str_content = cast(str, content)
mm_parser.parse_audio(str_content)

View File

@ -134,3 +134,22 @@ class ImageMediaIO(MediaIO[Image.Image]):
data = buffer.getvalue()
return base64.b64encode(data).decode('utf-8')
class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
def __init__(self) -> None:
super().__init__()
def load_bytes(self, data: bytes) -> torch.Tensor:
buffer = BytesIO(data)
return torch.load(buffer, weights_only=True)
def load_base64(self, media_type: str, data: str) -> torch.Tensor:
return self.load_bytes(base64.b64decode(data))
def load_file(self, filepath: Path) -> torch.Tensor:
return torch.load(filepath)
def encode_base64(self, media: torch.Tensor) -> str:
return base64.b64encode(media.numpy()).decode('utf-8')

View File

@ -7,6 +7,7 @@ from urllib.parse import ParseResult, urlparse
import numpy as np
import numpy.typing as npt
import torch
from PIL import Image
import vllm.envs as envs
@ -16,7 +17,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from .audio import AudioMediaIO
from .base import MediaIO
from .image import ImageMediaIO
from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .inputs import PlaceholderRange
from .video import VideoMediaIO
@ -245,6 +246,17 @@ class MediaConnector:
fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
)
def fetch_image_embedding(
self,
data: str,
) -> torch.Tensor:
"""
Load image embedding from a URL.
"""
image_embedding_io = ImageEmbeddingMediaIO()
return image_embedding_io.load_base64("", data)
global_media_connector = MediaConnector()
"""The global :class:`MediaConnector` instance used by vLLM."""