[Frontend] support image embeds (#13955)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Chauncey 2025-03-10 20:36:03 +08:00 committed by GitHub
parent 60a98b2de5
commit b0746fae3d
4 changed files with 200 additions and 11 deletions


@@ -462,4 +462,69 @@ export VLLM_AUDIO_FETCH_TIMEOUT=<timeout>
### Embedding Inputs
To input pre-computed embeddings belonging to a modality (i.e. image, video, or audio) directly to the language model,
pass a tensor of the expected shape to the corresponding field of the multi-modal dictionary.
#### Image Embedding Inputs
For image embeddings, you can pass the base64-encoded tensor to the `image_embeds` field.
The following example demonstrates how to pass image embeddings to vLLM's OpenAI-compatible server:
```python
import base64
import io

import torch
from openai import OpenAI

# Pre-computed embeddings (e.g. produced offline and saved with torch.save)
image_embedding = torch.load(...)
grid_thw = torch.load(...)  # Required by Qwen/Qwen2-VL-2B-Instruct

# Serialize the embedding tensor and base64-encode the bytes for transport
buffer = io.BytesIO()
torch.save(image_embedding, buffer)
buffer.seek(0)
binary_data = buffer.read()
base64_image_embedding = base64.b64encode(binary_data).decode('utf-8')

# Additional tensors (e.g. image_grid_thw) are encoded the same way
buffer = io.BytesIO()
torch.save(grid_thw, buffer)
buffer.seek(0)
base64_image_grid_thw = base64.b64encode(buffer.read()).decode('utf-8')

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

# Basic usage - this is equivalent to the LLaVA example for offline inference
model = "llava-hf/llava-1.5-7b-hf"
embeds = {
    "type": "image_embeds",
    "image_embeds": f"{base64_image_embedding}"
}

# Pass additional parameters (available to Qwen2-VL and MiniCPM-V)
model = "Qwen/Qwen2-VL-2B-Instruct"
embeds = {
    "type": "image_embeds",
    "image_embeds": {
        "image_embeds": f"{base64_image_embedding}",  # Required
        "image_grid_thw": f"{base64_image_grid_thw}"  # Required by Qwen/Qwen2-VL-2B-Instruct
    },
}

# (base64_image_sizes is produced from the image-sizes tensor in the same way)
model = "openbmb/MiniCPM-V-2_6"
embeds = {
    "type": "image_embeds",
    "image_embeds": {
        "image_embeds": f"{base64_image_embedding}",  # Required
        "image_sizes": f"{base64_image_sizes}"  # Required by openbmb/MiniCPM-V-2_6
    },
}
chat_completion = client.chat.completions.create(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [
            {
                "type": "text",
                "text": "What's in this image?",
            },
            embeds,
        ],
        },
    ],
    model=model,
)
```
:::{note}
Only one message can contain `{"type": "image_embeds"}`.
If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. `image_grid_thw`, `image_sizes`, etc.
:::
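
On the server side, vLLM decodes each base64 payload back into a tensor with `torch.load` (see the new `ImageEmbeddingMediaIO` class below). If you want to sanity-check a payload before sending it, you can mirror that decoding locally. A minimal sketch, assuming `image_embedding` and `base64_image_embedding` are defined as in the example above:

```python
import base64
import io

import torch

# Reverse the encoding above: base64 -> bytes -> torch.load
decoded_bytes = base64.b64decode(base64_image_embedding)
restored = torch.load(io.BytesIO(decoded_bytes), weights_only=True)
assert restored.shape == image_embedding.shape
```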


@@ -56,6 +56,17 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """The type of the content part."""


class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
    image_embeds: Required[Union[str, dict[str, str]]]
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""


class VideoURL(TypedDict, total=False):
    url: Required[str]
    """
@@ -109,6 +120,7 @@ ChatCompletionContentPartParam: TypeAlias = Union[
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentSimpleImageParam,
    ChatCompletionContentPartImageEmbedsParam,
    CustomChatCompletionContentSimpleAudioParam,
    CustomChatCompletionContentSimpleVideoParam, str]
@@ -350,7 +362,7 @@ def resolve_chat_template_content_format(
    return detected_format


ModalityStr = Literal["image", "audio", "video", "image_embeds"]
_T = TypeVar("_T")
@@ -391,7 +403,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
        hf_config = self._model_config.hf_config
        model_type = hf_config.model_type

        if modality in ["image", "image_embeds"]:
            if model_type == "phi3_v":
                # Workaround since this token is not defined in the tokenizer
                return f"<|image_{current_count}|>"
@@ -470,10 +482,27 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):

class MultiModalItemTracker(BaseMultiModalItemTracker[object]):

    def all_mm_data(self) -> Optional[MultiModalDataDict]:
        if not self._items_by_modality:
            return None

        mm_inputs = {}
        items_by_modality = dict(self._items_by_modality)

        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed")

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}")
            mm_inputs["image"] = image_embeds_lst[0]
        elif "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        elif "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        elif "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos

        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return MultiModalContentParser(self)
@@ -482,13 +511,31 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):

class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):

    async def all_mm_data(self) -> Optional[MultiModalDataDict]:
        if not self._items_by_modality:
            return None

        mm_inputs = {}
        items_by_modality = {
            modality: await asyncio.gather(*items)
            for modality, items in self._items_by_modality.items()
        }

        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed")

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}")
            mm_inputs["image"] = image_embeds_lst[0]
        elif "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        elif "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        elif "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos

        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return AsyncMultiModalContentParser(self)
@@ -513,6 +560,11 @@ class BaseMultiModalContentParser(ABC):
    def parse_image(self, image_url: str) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(self,
                           image_embeds: Union[str, dict[str, str]]) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str) -> None:
        raise NotImplementedError
@@ -543,6 +595,21 @@ class MultiModalContentParser(BaseMultiModalContentParser):
        placeholder = self._tracker.add("image", image)
        self._add_placeholder(placeholder)

    def parse_image_embeds(self,
                           image_embeds: Union[str, dict[str, str]]) -> None:
        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            placeholder = self._tracker.add("image_embeds", embeds)

        if isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            placeholder = self._tracker.add("image_embeds", embedding)

        self._add_placeholder(placeholder)

    def parse_audio(self, audio_url: str) -> None:
        audio = self._connector.fetch_audio(audio_url)
@@ -579,6 +646,25 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
        placeholder = self._tracker.add("image", image_coro)
        self._add_placeholder(placeholder)

    def parse_image_embeds(self,
                           image_embeds: Union[str, dict[str, str]]) -> None:
        future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()

        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            future.set_result(embeds)

        if isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            future.set_result(embedding)

        placeholder = self._tracker.add("image_embeds", future)
        self._add_placeholder(placeholder)

    def parse_audio(self, audio_url: str) -> None:
        audio_coro = self._connector.fetch_audio_async(audio_url)
@@ -684,6 +770,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
# No need to validate using Pydantic again
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
@@ -700,6 +787,8 @@ MM_PARSER_MAP: dict[
    lambda part: _TextParser(part).get("text", ""),
    "image_url":
    lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
    "image_embeds":
    lambda part: _ImageEmbedsParser(part).get("image_embeds", {}),
    "audio_url":
    lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
    "input_audio":
@@ -769,6 +858,7 @@ def _parse_chat_message_content_mm_part(
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
                                       "image_embeds",
                                       "audio_url", "input_audio", "video_url")
@@ -843,7 +933,10 @@ def _parse_chat_message_content_part(
        str_content = cast(str, content)
        mm_parser.parse_image(str_content)
        return {'type': 'image'} if wrap_dicts else None

    if part_type == "image_embeds":
        content = cast(Union[str, dict[str, str]], content)
        mm_parser.parse_image_embeds(content)
        return {'type': 'image'} if wrap_dicts else None

    if part_type == "audio_url":
        str_content = cast(str, content)
        mm_parser.parse_audio(str_content)


@@ -134,3 +134,22 @@ class ImageMediaIO(MediaIO[Image.Image]):
        data = buffer.getvalue()
        return base64.b64encode(data).decode('utf-8')


class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):

    def __init__(self) -> None:
        super().__init__()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        buffer = BytesIO(data)
        return torch.load(buffer, weights_only=True)

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> torch.Tensor:
        return torch.load(filepath)

    def encode_base64(self, media: torch.Tensor) -> str:
        return base64.b64encode(media.numpy()).decode('utf-8')


@@ -7,6 +7,7 @@ from urllib.parse import ParseResult, urlparse

import numpy as np
import numpy.typing as npt
import torch
from PIL import Image

import vllm.envs as envs
@@ -16,7 +17,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer

from .audio import AudioMediaIO
from .base import MediaIO
from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .inputs import PlaceholderRange
from .video import VideoMediaIO
@@ -245,6 +246,17 @@ class MediaConnector:
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Load image embeddings from a base64-encoded string.
        """
        image_embedding_io = ImageEmbeddingMediaIO()
        return image_embedding_io.load_base64("", data)


global_media_connector = MediaConnector()
"""The global :class:`MediaConnector` instance used by vLLM."""