[Model] Adding Support for Qwen2VL as an Embedding Model. Using MrLight/dse-qwen2-2b-mrl-v1 (#9944)

Signed-off-by: FurtherAI <austin.veselka@lighton.ai>
Co-authored-by: FurtherAI <austin.veselka@lighton.ai>
Austin Veselka 2024-11-13 02:28:13 -06:00 committed by GitHub
parent 3945c82346
commit 1b886aa104
8 changed files with 363 additions and 18 deletions


@@ -584,6 +584,12 @@ Multimodal Embedding
    - :code:`TIGER-Lab/VLM2Vec-Full`
    - 🚧
    - ✅︎
  * - :code:`Qwen2VLForConditionalGeneration`
    - Qwen2-VL-based
    - T + I
    - :code:`MrLight/dse-qwen2-2b-mrl-v1`
    -
    - ✅︎

.. important::
    Some model architectures support both generation and embedding tasks.


@@ -310,4 +310,21 @@ Since the request schema is not defined by OpenAI client, we post a request to t
    response_json = response.json()
    print("Embedding output:", response_json["data"][0]["embedding"])

Here is an example for serving the ``MrLight/dse-qwen2-2b-mrl-v1`` model.

.. code-block:: bash

    vllm serve MrLight/dse-qwen2-2b-mrl-v1 --task embedding \
        --trust-remote-code --max-model-len 8192 --chat-template examples/template_dse_qwen2_vl.jinja

.. important::

    Like with VLM2Vec, we have to explicitly pass ``--task embedding``. Additionally, ``MrLight/dse-qwen2-2b-mrl-v1`` requires an EOS token for embeddings,
    which is handled by the jinja template.

.. important::

    ``MrLight/dse-qwen2-2b-mrl-v1`` also requires a placeholder image of the minimum image size for text-only query embeddings. See the full code
    example below for details.

A full code example can be found in `examples/openai_chat_embedding_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_chat_embedding_client_for_multimodal.py>`_.
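As a condensed sketch of the text-query path described above (the full script in the next file does the same thing with argument handling; the query string here is only an illustrative value), the request boils down to sending a minimum-size placeholder image as a base64 data URL alongside a ``Query:``-prefixed text:

.. code-block:: python

    import base64
    import io

    import requests
    from PIL import Image

    # Minimum-size (56x56) placeholder image, sent as a base64 data URL.
    buffer = io.BytesIO()
    Image.new("RGB", (56, 56)).save(buffer, "png")
    placeholder = base64.b64encode(buffer.getvalue()).decode("utf-8")

    response = requests.post(
        "http://localhost:8000/v1/embeddings",
        json={
            "model": "MrLight/dse-qwen2-2b-mrl-v1",
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "image_url",
                     "image_url": {"url": f"data:image/png;base64,{placeholder}"}},
                    {"type": "text", "text": "Query: cherry blossom"},
                ],
            }],
            "encoding_format": "float",
        },
    )
    response.raise_for_status()
    print("Embedding output:", response.json()["data"][0]["embedding"])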


@@ -1,33 +1,120 @@
import argparse
import base64
import io

import requests
from PIL import Image

image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"


def vlm2vec():
    response = requests.post(
        "http://localhost:8000/v1/embeddings",
        json={
            "model": "TIGER-Lab/VLM2Vec-Full",
            "messages": [{
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": image_url
                        }
                    },
                    {
                        "type": "text",
                        "text": "Represent the given image."
                    },
                ],
            }],
            "encoding_format": "float",
        },
    )
    response.raise_for_status()
    response_json = response.json()

    print("Embedding output:", response_json["data"][0]["embedding"])


def dse_qwen2_vl(inp: dict):
    # Embedding an Image
    if inp["dtype"] == "image":
        messages = [{
            "role": "user",
            "content": [{
                "type": "image_url",
                "image_url": {
                    "url": inp["image_url"],
                }
            }, {
                "type": "text",
                "text": "What is shown in this image?"
            }]
        }]
    # Embedding a Text Query
    else:
        # MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image
        # of the minimum input size for text-only queries
        buffer = io.BytesIO()
        image_placeholder = Image.new("RGB", (56, 56))
        image_placeholder.save(buffer, "png")
        buffer.seek(0)
        image_placeholder = base64.b64encode(buffer.read()).decode('utf-8')
        messages = [{
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{image_placeholder}",
                    }
                },
                {
                    "type": "text",
                    "text": f"Query: {inp['content']}"
                },
            ],
        }]

    response = requests.post(
        "http://localhost:8000/v1/embeddings",
        json={
            "model": "MrLight/dse-qwen2-2b-mrl-v1",
            "messages": messages,
            "encoding_format": "float",
        },
    )
    response.raise_for_status()
    response_json = response.json()

    print("Embedding output:", response_json["data"][0]["embedding"])


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        "Script to call a specified VLM through the API. Make sure to serve "
        "the model with --task embedding before running this.")
    # "model" is positional, so argparse already requires it.
    parser.add_argument("model",
                        type=str,
                        choices=["vlm2vec", "dse_qwen2_vl"],
                        help="Which model to call.")
    args = parser.parse_args()

    if args.model == "vlm2vec":
        vlm2vec()
    elif args.model == "dse_qwen2_vl":
        dse_qwen2_vl({
            "dtype": "image",
            "image_url": image_url,
        })
        dse_qwen2_vl({
            "dtype": "text",
            "content": "What is the weather like today?",
        })
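With the server from the serving example above running, the client is invoked with the model to query as its positional argument, e.g. ``python examples/openai_chat_embedding_client_for_multimodal.py dse_qwen2_vl``; the ``dse_qwen2_vl`` path sends one image request and one text-query request against the running server.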


@@ -0,0 +1,7 @@
{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{% raw %}<|im_start|>system
You are a helpful assistant.<|im_end|>
{% endraw %}{% endif %}<|im_start|>{{ message['role'] }}{% raw %}
{% endraw %}{% if message['content'] is string %}{{ message['content'] }}<|im_end|>{% raw %}
{% endraw %}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>{% raw %}
{% endraw %}{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant{% raw %}
{% endraw %}{% endif %}<|endoftext|>
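To see what this template produces, the following minimal sketch renders a single-image query offline (assumptions: the template above is saved at ``examples/template_dse_qwen2_vl.jinja``, the model's tokenizer can be fetched from the Hub, and the base64 payload is elided). The rendered prompt wraps the image in ``<|vision_start|><|image_pad|><|vision_end|>`` and ends with the ``<|endoftext|>`` EOS token that the serving docs rely on.

from pathlib import Path

from transformers import AutoTokenizer

template = Path("examples/template_dse_qwen2_vl.jinja").read_text()
tokenizer = AutoTokenizer.from_pretrained("MrLight/dse-qwen2-2b-mrl-v1")

messages = [{
    "role": "user",
    "content": [
        # Any content dict carrying an "image_url" key is rendered as an image pad.
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
        {"type": "text", "text": "Query: cherry blossom"},
    ],
}]

prompt = tokenizer.apply_chat_template(messages,
                                       chat_template=template,
                                       tokenize=False,
                                       add_generation_prompt=True)
print(prompt)  # ...<|vision_start|><|image_pad|><|vision_end|>...<|endoftext|>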


@@ -243,6 +243,9 @@ _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
class HfRunner:

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
        if x is None or isinstance(x, (bool, )):
            return x

        if device is None:
            device = "cpu" if current_platform.is_cpu() else "cuda"


@@ -0,0 +1,209 @@
from functools import partial
from typing import Callable, Dict, List, Type

import pytest
import torch
from PIL import Image
from transformers import BatchEncoding, Qwen2VLForConditionalGeneration

from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ....utils import large_gpu_test
from ..utils import check_embeddings_close

HF_TEXT_PROMPTS = [
    # T -> X
    (
        "Query: Find me an everyday image that matches the given caption: The label of the object is stop sign",  # noqa: E501
        Image.new("RGB", (56, 56))),
    # T -> X
    ("Query: Retrieve an image of this caption: cherry blossom",
     Image.new("RGB", (56, 56))),
]

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "What is shown in this image?",
    "cherry_blossom":
    "What is shown in this image?"
})

MODELS = ["MrLight/dse-qwen2-2b-mrl-v1"]


def get_messages(image: Image.Image, text: str, embed_text: bool):
    # assert False, 'remember to use outer [] as required'
    if embed_text:
        messages = [{
            "role":
            "user",
            "content": [
                {
                    "type": "image",
                    "image": Image.new("RGB", (56, 56)),
                    "resized_height": 1,
                    "resized_width": 1
                },  # need a dummy image here for an easier process.
                {
                    "type": "text",
                    "text": text
                },
            ]
        }]
    else:
        messages = [{
            "role":
            "user",
            "content": [{
                "type": "image",
                "image": image
            }, {
                "type": "text",
                "text": text
            }]
        }]
    return messages


def apply_chat_template_and_add_eos(
    messages: List[Dict],
    apply_chat_template_fn: Callable,
):
    prompt = apply_chat_template_fn(
        messages, tokenize=False, add_generation_prompt=True) + "<|endoftext|>"
    return prompt


def postprocess_inputs(hf_model: HfRunner, inputs: BatchEncoding, **kwargs):
    return hf_model.model.prepare_inputs_for_generation(**inputs, **kwargs)


def _run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    input_texts: List[str],
    input_images: PromptImageInput,
    embed_texts: List[bool],
    model: str,
    *,
    dtype: str,
) -> None:
    '''SET PYTHONPATH'''
    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
    with vllm_runner(model,
                     task="embedding",
                     dtype=dtype,
                     enforce_eager=True,
                     max_model_len=8192) as vllm_model:
        tokenizer = vllm_model.model.get_tokenizer()
        texts = [
            # this is necessary because vllm_model.encode will not apply any
            # templating to the prompt, and therefore lacks an image_pad
            # token unless one is inserted beforehand (the (28,28) image
            # above is converted to an image pad token by the chat template).
            apply_chat_template_and_add_eos(
                get_messages(image, text, False),
                apply_chat_template_fn=tokenizer.apply_chat_template,
            ) for text, image in zip(input_texts, input_images)
            # vllm will replace the pad token with the actual image,
            # which may be a placeholder image, later.
        ]
        vllm_outputs = vllm_model.encode(texts, images=input_images)

    hf_outputs = []
    with hf_runner(model,
                   dtype=dtype,
                   auto_cls=Qwen2VLForConditionalGeneration) as hf_model:
        hf_model.postprocess_inputs = partial(
            postprocess_inputs,
            hf_model,
            cache_position=torch.arange(
                0,
                1,  # 1 for batch size
                requires_grad=False),
            use_cache=False)
        for text, image, embed_text in zip(input_texts, input_images,
                                           embed_texts):
            # dse requires non-standard input processing
            # because it needs an image_pad token
            messages = get_messages(image, text, embed_text)
            prompt = apply_chat_template_and_add_eos(
                messages, hf_model.processor.apply_chat_template)
            inputs = hf_model.get_inputs(
                prompts=[[prompt]],
                images=[[image]],
            )
            with torch.no_grad():
                outputs = hf_model.model(
                    **hf_model.wrap_device(inputs[0],
                                           device=hf_model.model.device.type),
                    return_dict=True,
                    output_hidden_states=True,
                )
                pooled_output = torch.nn.functional.normalize(
                    outputs.hidden_states[-1][0, -1], p=2, dim=-1)
            hf_outputs.append(pooled_output.tolist())

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_models_text(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
    dtype: str,
) -> None:
    input_texts_images = [(text, image_placeholder)
                          for text, image_placeholder in HF_TEXT_PROMPTS]
    input_texts = [text for text, _ in input_texts_images]
    input_images = [image for _, image in input_texts_images]
    embed_texts = [True] * len(input_texts)

    _run_test(
        hf_runner,
        vllm_runner,
        input_texts,
        input_images,  # type: ignore
        embed_texts,
        model,
        dtype=dtype,
    )


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_models_image(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
    dtype: str,
) -> None:
    input_texts_images = [
        (text, asset.pil_image)
        for text, asset in zip(HF_IMAGE_PROMPTS, image_assets)
    ]
    input_texts = [text for text, _ in input_texts_images]
    input_images = [image for _, image in input_texts_images]
    embed_texts = [False] * len(input_texts)

    _run_test(
        hf_runner,
        vllm_runner,
        input_texts,
        input_images,
        embed_texts,
        model,
        dtype=dtype,
    )


@@ -51,6 +51,7 @@ from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import (GPTQConfig,
                                                     GPTQMarlinConfig,
                                                     QuantizationConfig)

@@ -58,12 +59,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                             MultiModalKwargs)
from vllm.multimodal.base import MultiModalData
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import cached_get_processor

@@ -1067,6 +1069,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        pooler_config = vllm_config.model_config.pooler_config
        multimodal_config = vllm_config.model_config.multimodal_config
        assert not cache_config.enable_prefix_caching, \
            "Qwen2-VL currently does not support prefix caching"

@@ -1098,6 +1101,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.LAST,
            normalize=True,
            softmax=False)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

@@ -1318,6 +1326,13 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
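The ``Pooler`` wiring added above (``PoolingType.LAST`` with ``normalize=True`` and ``softmax=False``) amounts to taking the hidden state of each sequence's final token and L2-normalizing it, matching the HF reference path in the test. A minimal illustrative sketch, independent of vLLM's ``Pooler``/``PoolingMetadata`` types and using made-up tensor shapes:

import torch
import torch.nn.functional as F


def last_token_embed(hidden_states: torch.Tensor,
                     seq_lens: torch.Tensor) -> torch.Tensor:
    # hidden_states: [batch, max_seq_len, hidden_size] (hypothetical padded layout)
    # seq_lens: [batch] count of real (non-padded) tokens per sequence
    batch = hidden_states.size(0)
    last = hidden_states[torch.arange(batch), seq_lens - 1]  # [batch, hidden_size]
    return F.normalize(last, p=2, dim=-1)  # unit-norm embeddings


# Example: two sequences of lengths 5 and 3, hidden size 8.
emb = last_token_embed(torch.randn(2, 7, 8), torch.tensor([5, 3]))
print(emb.shape)  # torch.Size([2, 8])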


@@ -109,6 +109,7 @@ _EMBEDDING_MODELS = {
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
}

_MULTIMODAL_MODELS = {