[Frontend] support image embeds (#13955)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
parent 60a98b2de5
commit b0746fae3d
@ -462,4 +462,69 @@ export VLLM_AUDIO_FETCH_TIMEOUT=<timeout>
### Embedding Inputs

To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model,
pass a tensor with the shape expected by the model to the corresponding field of the multi-modal dictionary.

#### Image Embedding Inputs

For image embeddings, you can pass the base64-encoded tensor to the `image_embeds` field.
The following example demonstrates how to pass image embeddings to the OpenAI-compatible server:

```python
import base64
import io

import torch
from openai import OpenAI

# Assumed server settings (matching the other OpenAI-compatible server examples).
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

image_embedding = torch.load(...)
grid_thw = torch.load(...)  # Required by Qwen/Qwen2-VL-2B-Instruct

buffer = io.BytesIO()
torch.save(image_embedding, buffer)
buffer.seek(0)
binary_data = buffer.read()
base64_image_embedding = base64.b64encode(binary_data).decode('utf-8')
# Additional tensors (e.g. grid_thw, image_sizes) are encoded the same way to
# produce base64_image_grid_thw / base64_image_sizes used below.

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

# Basic usage - this is equivalent to the LLaVA example for offline inference
model = "llava-hf/llava-1.5-7b-hf"
embeds = {
    "type": "image_embeds",
    "image_embeds": f"{base64_image_embedding}",
}

# Pass additional parameters (available to Qwen2-VL and MiniCPM-V)
model = "Qwen/Qwen2-VL-2B-Instruct"
embeds = {
    "type": "image_embeds",
    "image_embeds": {
        "image_embeds": f"{base64_image_embedding}",  # Required
        "image_grid_thw": f"{base64_image_grid_thw}",  # Required by Qwen/Qwen2-VL-2B-Instruct
    },
}
model = "openbmb/MiniCPM-V-2_6"
embeds = {
    "type": "image_embeds",
    "image_embeds": {
        "image_embeds": f"{base64_image_embedding}",  # Required
        "image_sizes": f"{base64_image_sizes}",  # Required by openbmb/MiniCPM-V-2_6
    },
}

chat_completion = client.chat.completions.create(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [
            {
                "type": "text",
                "text": "What's in this image?",
            },
            embeds,
        ],
        },
    ],
    model=model,
)
```

:::{note}
Only one message can contain `{"type": "image_embeds"}`.
If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. `image_grid_thw`, `image_sizes`, etc.
:::
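Not part of this diff: the extra tensors referenced above (`grid_thw`, `image_sizes`) must be serialized the same way as the embeddings. A minimal helper sketch, assuming the tensors are already materialized on CPU (the function name is illustrative):

```python
import base64
import io

import torch


def tensor_to_base64(tensor: torch.Tensor) -> str:
    """Serialize a tensor with torch.save and return it as a base64 string."""
    buffer = io.BytesIO()
    torch.save(tensor, buffer)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


# Hypothetical usage for the Qwen2-VL example above:
# base64_image_embedding = tensor_to_base64(image_embedding)
# base64_image_grid_thw = tensor_to_base64(grid_thw)
```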
@ -56,6 +56,17 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """The type of the content part."""


class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
    image_embeds: Required[Union[str, dict[str, str]]]
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""


class VideoURL(TypedDict, total=False):
    url: Required[str]
    """
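For illustration only (values are placeholders, not from the diff): the two payload shapes accepted by `ChatCompletionContentPartImageEmbedsParam` look like this when written as plain dicts:

```python
# Single base64 string (sufficient for models such as LLaVA).
simple_part = {
    "type": "image_embeds",
    "image_embeds": "<base64-encoded tensor>",
}

# Dictionary of base64 strings, for models that need extra tensors
# alongside the embeddings (e.g. image_grid_thw for Qwen2-VL).
dict_part = {
    "type": "image_embeds",
    "image_embeds": {
        "image_embeds": "<base64-encoded tensor>",
        "image_grid_thw": "<base64-encoded tensor>",
    },
}
```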
@ -109,6 +120,7 @@ ChatCompletionContentPartParam: TypeAlias = Union[
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentSimpleImageParam,
    ChatCompletionContentPartImageEmbedsParam,
    CustomChatCompletionContentSimpleAudioParam,
    CustomChatCompletionContentSimpleVideoParam, str]
@ -350,7 +362,7 @@ def resolve_chat_template_content_format(
    return detected_format


ModalityStr = Literal["image", "audio", "video", "image_embeds"]
_T = TypeVar("_T")
@ -391,7 +403,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
        hf_config = self._model_config.hf_config
        model_type = hf_config.model_type

        if modality in ["image", "image_embeds"]:
            if model_type == "phi3_v":
                # Workaround since this token is not defined in the tokenizer
                return f"<|image_{current_count}|>"
@ -470,10 +482,27 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):

    def all_mm_data(self) -> Optional[MultiModalDataDict]:
        if not self._items_by_modality:
            return None
        mm_inputs = {}
        items_by_modality = dict(self._items_by_modality)

        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed")

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}")
            mm_inputs["image"] = image_embeds_lst[0]
        elif "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        elif "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        elif "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return MultiModalContentParser(self)
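A standalone sketch of the image-related rules that `all_mm_data` enforces above; the function below is illustrative, not vLLM API:

```python
from typing import Any


def merge_image_items(items_by_modality: dict[str, list[Any]]) -> dict[str, Any]:
    """Illustrative re-statement of the image/image_embeds merging rules."""
    if "image" in items_by_modality and "image_embeds" in items_by_modality:
        raise ValueError("Mixing raw image and embedding inputs is not allowed")

    mm_inputs: dict[str, Any] = {}
    if "image_embeds" in items_by_modality:
        image_embeds_lst = items_by_modality["image_embeds"]
        if len(image_embeds_lst) > 1:
            raise ValueError("Only one message can have {'type': 'image_embeds'}")
        # The single embedding item is forwarded under the "image" key.
        mm_inputs["image"] = image_embeds_lst[0]
    elif "image" in items_by_modality:
        mm_inputs["image"] = items_by_modality["image"]  # A list of images
    return mm_inputs
```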
@ -482,13 +511,31 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):

    async def all_mm_data(self) -> Optional[MultiModalDataDict]:
        if not self._items_by_modality:
            return None
        mm_inputs = {}
        items_by_modality = {
            modality: await asyncio.gather(*items)
            for modality, items in self._items_by_modality.items()
        }

        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed")

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}")
            mm_inputs["image"] = image_embeds_lst[0]
        elif "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        elif "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        elif "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return AsyncMultiModalContentParser(self)
@ -513,6 +560,11 @@ class BaseMultiModalContentParser(ABC):
    def parse_image(self, image_url: str) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(self,
                           image_embeds: Union[str, dict[str, str]]) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str) -> None:
        raise NotImplementedError
@ -543,6 +595,21 @@ class MultiModalContentParser(BaseMultiModalContentParser):
        placeholder = self._tracker.add("image", image)
        self._add_placeholder(placeholder)

    def parse_image_embeds(self,
                           image_embeds: Union[str, dict[str, str]]) -> None:
        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            placeholder = self._tracker.add("image_embeds", embeds)

        if isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            placeholder = self._tracker.add("image_embeds", embedding)

        self._add_placeholder(placeholder)

    def parse_audio(self, audio_url: str) -> None:
        audio = self._connector.fetch_audio(audio_url)
@ -579,6 +646,25 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
        placeholder = self._tracker.add("image", image_coro)
        self._add_placeholder(placeholder)

    def parse_image_embeds(self,
                           image_embeds: Union[str, dict[str, str]]) -> None:
        future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()

        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            future.set_result(embeds)

        if isinstance(image_embeds, str):
            embedding = self._connector.\
                fetch_image_embedding(image_embeds)
            future.set_result(embedding)

        placeholder = self._tracker.add("image_embeds", future)
        self._add_placeholder(placeholder)

    def parse_audio(self, audio_url: str) -> None:
        audio_coro = self._connector.fetch_audio_async(audio_url)
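The async parser decodes the base64 payload synchronously but wraps the result in an already-resolved `asyncio.Future`, so it can be awaited alongside genuinely asynchronous fetches (e.g. image URLs) via `asyncio.gather`. A minimal illustration of that pattern, separate from the vLLM code:

```python
import asyncio


async def main() -> None:
    async def fetch_remote() -> str:
        await asyncio.sleep(0)  # stands in for an async image download
        return "fetched image"

    # A value that is already available, wrapped in a resolved Future so that
    # both items can be awaited uniformly.
    ready: asyncio.Future[str] = asyncio.get_running_loop().create_future()
    ready.set_result("decoded embedding")

    print(await asyncio.gather(fetch_remote(), ready))


asyncio.run(main())  # ['fetched image', 'decoded embedding']
```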
@ -684,6 +770,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
# No need to validate using Pydantic again
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
@ -700,6 +787,8 @@ MM_PARSER_MAP: dict[
    lambda part: _TextParser(part).get("text", ""),
    "image_url":
    lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
    "image_embeds":
    lambda part: _ImageEmbedsParser(part).get("image_embeds", {}),
    "audio_url":
    lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
    "input_audio":
@ -769,6 +858,7 @@ def _parse_chat_message_content_mm_part(
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
                                       "image_embeds",
                                       "audio_url", "input_audio", "video_url")
@ -843,7 +933,10 @@ def _parse_chat_message_content_part(
        str_content = cast(str, content)
        mm_parser.parse_image(str_content)
        return {'type': 'image'} if wrap_dicts else None

    if part_type == "image_embeds":
        content = cast(Union[str, dict[str, str]], content)
        mm_parser.parse_image_embeds(content)
        return {'type': 'image'} if wrap_dicts else None

    if part_type == "audio_url":
        str_content = cast(str, content)
        mm_parser.parse_audio(str_content)
@ -134,3 +134,22 @@ class ImageMediaIO(MediaIO[Image.Image]):
        data = buffer.getvalue()

        return base64.b64encode(data).decode('utf-8')


class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):

    def __init__(self) -> None:
        super().__init__()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        buffer = BytesIO(data)
        return torch.load(buffer, weights_only=True)

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> torch.Tensor:
        return torch.load(filepath)

    def encode_base64(self, media: torch.Tensor) -> str:
        return base64.b64encode(media.numpy()).decode('utf-8')
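A possible round trip with the new class, assuming vLLM is installed and `ImageEmbeddingMediaIO` is importable from `vllm.multimodal.image` (the module path is inferred from the relative import in the next hunk):

```python
import base64
import io

import torch

from vllm.multimodal.image import ImageEmbeddingMediaIO  # assumed import path

# Client side: serialize with torch.save and base64-encode, as in the docs above.
embedding = torch.randn(1, 576, 4096)
buf = io.BytesIO()
torch.save(embedding, buf)
payload = base64.b64encode(buf.getvalue()).decode("utf-8")

# Server side: load_base64() base64-decodes and torch.load()s the tensor back.
restored = ImageEmbeddingMediaIO().load_base64("", payload)
assert torch.equal(embedding, restored)
```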
@ -7,6 +7,7 @@ from urllib.parse import ParseResult, urlparse
import numpy as np
import numpy.typing as npt
import torch
from PIL import Image

import vllm.envs as envs
@ -16,7 +17,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from .audio import AudioMediaIO
from .base import MediaIO
from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .inputs import PlaceholderRange
from .video import VideoMediaIO
@ -245,6 +246,17 @@ class MediaConnector:
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Load an image embedding from a base64-encoded tensor payload.
        """
        image_embedding_io = ImageEmbeddingMediaIO()

        return image_embedding_io.load_base64("", data)


global_media_connector = MediaConnector()
"""The global :class:`MediaConnector` instance used by vLLM."""