[Model] Adding Support for Qwen2VL as an Embedding Model. Using MrLight/dse-qwen2-2b-mrl-v1 (#9944)
Signed-off-by: FurtherAI <austin.veselka@lighton.ai> Co-authored-by: FurtherAI <austin.veselka@lighton.ai>
parent 3945c82346
commit 1b886aa104
@@ -584,6 +584,12 @@ Multimodal Embedding
    - :code:`TIGER-Lab/VLM2Vec-Full`
    - 🚧
    - ✅︎
  * - :code:`Qwen2VLForConditionalGeneration`
    - Qwen2-VL-based
    - T + I
    - :code:`MrLight/dse-qwen2-2b-mrl-v1`
    -
    - ✅︎

.. important::

    Some model architectures support both generation and embedding tasks.
@@ -310,4 +310,21 @@ Since the request schema is not defined by OpenAI client, we post a request to t
    response_json = response.json()
    print("Embedding output:", response_json["data"][0]["embedding"])

Here is an example for serving the ``MrLight/dse-qwen2-2b-mrl-v1`` model.

.. code-block:: bash

    vllm serve MrLight/dse-qwen2-2b-mrl-v1 --task embedding \
      --trust-remote-code --max-model-len 8192 --chat-template examples/template_dse_qwen2_vl.jinja

.. important::

    Like with VLM2Vec, we have to explicitly pass ``--task embedding``. Additionally, ``MrLight/dse-qwen2-2b-mrl-v1`` requires an EOS token for embeddings,
    which is handled by the jinja template.

.. important::

    ``MrLight/dse-qwen2-2b-mrl-v1`` requires a placeholder image of the minimum image size for text query embeddings. See the full code
    example below for details.

A full code example can be found in `examples/openai_chat_embedding_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_chat_embedding_client_for_multimodal.py>`_.
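For reference, here is a rough sketch of how a client can build that placeholder image as a base64 data URL before sending a text-only query. It mirrors the approach taken in the example script below; the 56x56 size simply matches what the example uses.

.. code-block:: python

    import base64
    import io

    from PIL import Image

    # Encode a minimal 56x56 RGB image as PNG and wrap it in a data URL so it
    # can be passed through the OpenAI-style "image_url" content field.
    buffer = io.BytesIO()
    Image.new("RGB", (56, 56)).save(buffer, "png")
    buffer.seek(0)
    placeholder = base64.b64encode(buffer.read()).decode("utf-8")
    placeholder_url = f"data:image/png;base64,{placeholder}"
    print(placeholder_url[:60] + "...")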
@@ -1,8 +1,15 @@
import argparse
import base64
import io

import requests
from PIL import Image

image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"

response = requests.post(


def vlm2vec():
    response = requests.post(
        "http://localhost:8000/v1/embeddings",
        json={
            "model":
@@ -26,8 +33,88 @@ response = requests.post(
        "encoding_format":
        "float",
    },
)
response.raise_for_status()
response_json = response.json()
    )
    response.raise_for_status()
    response_json = response.json()

print("Embedding output:", response_json["data"][0]["embedding"])
    print("Embedding output:", response_json["data"][0]["embedding"])


def dse_qwen2_vl(inp: dict):
    # Embedding an Image
    if inp["dtype"] == "image":
        messages = [{
            "role":
            "user",
            "content": [{
                "type": "image_url",
                "image_url": {
                    "url": inp["image_url"],
                }
            }, {
                "type": "text",
                "text": "What is shown in this image?"
            }]
        }]
    # Embedding a Text Query
    else:
        # MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image
        # of the minimum input size
        buffer = io.BytesIO()
        image_placeholder = Image.new("RGB", (56, 56))
        image_placeholder.save(buffer, "png")
        buffer.seek(0)
        image_placeholder = base64.b64encode(buffer.read()).decode('utf-8')
        messages = [{
            "role":
            "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{image_placeholder}",
                    }
                },
                {
                    "type": "text",
                    "text": f"Query: {inp['content']}"
                },
            ]
        }]

    response = requests.post(
        "http://localhost:8000/v1/embeddings",
        json={
            "model": "MrLight/dse-qwen2-2b-mrl-v1",
            "messages": messages,
            "encoding_format": "float",
        },
    )
    response.raise_for_status()
    response_json = response.json()

    print("Embedding output:", response_json["data"][0]["embedding"])


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        "Script to call a specified VLM through the API. Make sure to serve "
        "the model with --task embedding before running this.")
    parser.add_argument("model",
                        type=str,
                        choices=["vlm2vec", "dse_qwen2_vl"],
                        help="Which model to call.")
    args = parser.parse_args()

    if args.model == "vlm2vec":
        vlm2vec()
    elif args.model == "dse_qwen2_vl":
        dse_qwen2_vl({
            "dtype": "image",
            "image_url": image_url,
        })
        dse_qwen2_vl({
            "dtype": "text",
            "content": "What is the weather like today?",
        })
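A small follow-up, not part of the example script: because the pooler for this model is configured with normalize=True (see the Pooler change to Qwen2VLForConditionalGeneration below), the returned embeddings are unit length, so relevance between a query embedding and a document embedding is just a dot product. A hypothetical helper, with toy vectors standing in for real API responses:

from typing import List

import numpy as np


def cosine_score(query_emb: List[float], doc_emb: List[float]) -> float:
    # Embeddings from this model come back L2-normalized, so the plain dot
    # product already equals cosine similarity.
    q = np.asarray(query_emb, dtype=np.float32)
    d = np.asarray(doc_emb, dtype=np.float32)
    return float(q @ d)


print(cosine_score([1.0, 0.0], [0.6, 0.8]))  # toy vectors -> 0.6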
examples/template_dse_qwen2_vl.jinja (new file, 7 lines)
@@ -0,0 +1,7 @@
{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{% raw %}<|im_start|>system
You are a helpful assistant.<|im_end|>
{% endraw %}{% endif %}<|im_start|>{{ message['role'] }}{% raw %}
{% endraw %}{% if message['content'] is string %}{{ message['content'] }}<|im_end|>{% raw %}
{% endraw %}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>{% raw %}
{% endraw %}{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant{% raw %}
{% endraw %}{% endif %}<|endoftext|>
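A rough way to sanity-check this template offline, assuming jinja2 is installed and the file is read from the path above. Plain jinja2 may differ from transformers' rendering in minor whitespace (transformers enables trim_blocks/lstrip_blocks), but the trailing <|endoftext|> EOS token should always be present:

import jinja2

with open("examples/template_dse_qwen2_vl.jinja") as f:
    template = jinja2.Template(f.read())

# Same message shape the OpenAI-style client sends for a text query.
messages = [{
    "role": "user",
    "content": [
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
        {"type": "text", "text": "Query: cherry blossom"},
    ],
}]
prompt = template.render(messages=messages, add_generation_prompt=True)
print(prompt)
assert prompt.rstrip().endswith("<|endoftext|>")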
@@ -243,6 +243,9 @@ _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
class HfRunner:

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
        if x is None or isinstance(x, (bool, )):
            return x

        if device is None:
            device = "cpu" if current_platform.is_cpu() else "cuda"
tests/models/embedding/vision_language/test_dse_qwen2_vl.py (new file, 209 lines)
@@ -0,0 +1,209 @@
from functools import partial
from typing import Callable, Dict, List, Type

import pytest
import torch
from PIL import Image
from transformers import BatchEncoding, Qwen2VLForConditionalGeneration

from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ....utils import large_gpu_test
from ..utils import check_embeddings_close

HF_TEXT_PROMPTS = [
    # T -> X
    (
        "Query: Find me an everyday image that matches the given caption: The label of the object is stop sign",  # noqa: E501,
        Image.new("RGB", (56, 56))),
    # T -> X
    ("Query: Retrieve an image of this caption: cherry blossom",
     Image.new("RGB", (56, 56))),
]

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "What is shown in this image?",
    "cherry_blossom":
    "What is shown in this image?"
})

MODELS = ["MrLight/dse-qwen2-2b-mrl-v1"]


def get_messages(image: Image.Image, text: str, embed_text: bool):
    # assert False, 'remember to use outer [] as required'
    if embed_text:
        messages = [{
            "role":
            "user",
            "content": [
                {
                    "type": "image",
                    "image": Image.new("RGB", (56, 56)),
                    "resized_height": 1,
                    "resized_width": 1
                },  # need a dummy image here for an easier process.
                {
                    "type": "text",
                    "text": text
                },
            ]
        }]
    else:
        messages = [{
            "role":
            "user",
            "content": [{
                "type": "image",
                "image": image
            }, {
                "type": "text",
                "text": text
            }]
        }]
    return messages


def apply_chat_template_and_add_eos(
    messages: List[Dict],
    apply_chat_template_fn: Callable,
):
    prompt = apply_chat_template_fn(
        messages, tokenize=False, add_generation_prompt=True) + "<|endoftext|>"
    return prompt


def postprocess_inputs(hf_model: HfRunner, inputs: BatchEncoding, **kwargs):
    return hf_model.model.prepare_inputs_for_generation(**inputs, **kwargs)


def _run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    input_texts: List[str],
    input_images: PromptImageInput,
    embed_texts: List[bool],
    model: str,
    *,
    dtype: str,
) -> None:
    '''SET PYTHONPATH'''
    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
    with vllm_runner(model,
                     task="embedding",
                     dtype=dtype,
                     enforce_eager=True,
                     max_model_len=8192) as vllm_model:
        tokenizer = vllm_model.model.get_tokenizer()
        texts = [
            # this is necessary because vllm_model.encode will not apply any
            # templating to the prompt, and therefore lacks an image_pad
            # token unless one is inserted beforehand (the (28,28) image
            # above is converted to an image pad token by the chat template).
            apply_chat_template_and_add_eos(
                get_messages(image, text, False),
                apply_chat_template_fn=tokenizer.apply_chat_template,
            ) for text, image in zip(input_texts, input_images)
            # vllm will replace the pad token with the actual image,
            # which may be a placeholder image, later.
        ]
        vllm_outputs = vllm_model.encode(texts, images=input_images)

    hf_outputs = []
    with hf_runner(model,
                   dtype=dtype,
                   auto_cls=Qwen2VLForConditionalGeneration) as hf_model:
        hf_model.postprocess_inputs = partial(
            postprocess_inputs,
            hf_model,
            cache_position=torch.arange(
                0,
                1,  # 1 for batch size
                requires_grad=False),
            use_cache=False)
        for text, image, embed_text in zip(input_texts, input_images,
                                           embed_texts):
            # dse requires non-standard input processing
            # because it needs an image_pad token
            messages = get_messages(image, text, embed_text)
            prompt = apply_chat_template_and_add_eos(
                messages, hf_model.processor.apply_chat_template)
            inputs = hf_model.get_inputs(
                prompts=[[prompt]],
                images=[[image]],
            )
            with torch.no_grad():
                outputs = hf_model.model(
                    **hf_model.wrap_device(inputs[0],
                                           device=hf_model.model.device.type),
                    return_dict=True,
                    output_hidden_states=True,
                )
                pooled_output = torch.nn.functional.normalize(
                    outputs.hidden_states[-1][0, -1], p=2, dim=-1)
            hf_outputs.append(pooled_output.tolist())

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_models_text(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
    dtype: str,
) -> None:
    input_texts_images = [(text, image_placeholder)
                          for text, image_placeholder in HF_TEXT_PROMPTS]
    input_texts = [text for text, _ in input_texts_images]
    input_images = [image for _, image in input_texts_images]
    embed_texts = [True] * len(input_texts)

    _run_test(
        hf_runner,
        vllm_runner,
        input_texts,
        input_images,  # type: ignore
        embed_texts,
        model,
        dtype=dtype,
    )


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_models_image(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
    dtype: str,
) -> None:
    input_texts_images = [
        (text, asset.pil_image)
        for text, asset in zip(HF_IMAGE_PROMPTS, image_assets)
    ]
    input_texts = [text for text, _ in input_texts_images]
    input_images = [image for _, image in input_texts_images]
    embed_texts = [False] * len(input_texts)

    _run_test(
        hf_runner,
        vllm_runner,
        input_texts,
        input_images,
        embed_texts,
        model,
        dtype=dtype,
    )
@@ -51,6 +51,7 @@ from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import (GPTQConfig,
                                                     GPTQMarlinConfig,
                                                     QuantizationConfig)
@@ -58,12 +59,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                             MultiModalKwargs)
from vllm.multimodal.base import MultiModalData
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import cached_get_processor
@@ -1067,6 +1069,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        pooler_config = vllm_config.model_config.pooler_config
        multimodal_config = vllm_config.model_config.multimodal_config
        assert not cache_config.enable_prefix_caching, \
            "Qwen2-VL currently does not support prefix caching"
@@ -1098,6 +1101,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.LAST,
            normalize=True,
            softmax=False)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))
@@ -1318,6 +1326,13 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
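For intuition, the pooler configured above (PoolingType.LAST with normalize=True) reduces to taking the last token's hidden state and L2-normalizing it, which is what the HF side of the new test above reproduces. A standalone sketch on a dummy tensor (shapes are illustrative only):

import torch
import torch.nn.functional as F

# Pretend last-layer hidden states: [batch, seq_len, hidden_size].
hidden_states = torch.randn(2, 7, 1536)

# Last-token pooling followed by L2 normalization.
embeddings = F.normalize(hidden_states[:, -1, :], p=2, dim=-1)
print(embeddings.shape)         # torch.Size([2, 1536])
print(embeddings.norm(dim=-1))  # ~1.0 per row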
@@ -109,6 +109,7 @@ _EMBEDDING_MODELS = {
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration")  # noqa: E501,
}

_MULTIMODAL_MODELS = {
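With the architecture registered as an embedding model, offline use through vLLM's LLM.encode should also be possible. A rough sketch mirroring what the new test exercises; treat the exact arguments as assumptions, the hand-written prompt is simply what examples/template_dse_qwen2_vl.jinja renders for a single-image message, and the local image path is hypothetical:

from PIL import Image
from vllm import LLM

llm = LLM(model="MrLight/dse-qwen2-2b-mrl-v1",
          task="embedding",
          trust_remote_code=True,
          max_model_len=8192)

# Prompt with the vision placeholder and trailing EOS, as produced by the
# example chat template.
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
          "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
          "What is shown in this image?<|im_end|>\n"
          "<|im_start|>assistant\n<|endoftext|>")

image = Image.open("my_image.jpg")  # hypothetical local file
outputs = llm.encode({"prompt": prompt, "multi_modal_data": {"image": image}})
print(outputs[0].outputs.embedding[:8])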