vllm/tests/models/embedding/vision_language/test_dse_qwen2_vl.py
Austin Veselka 1b886aa104
[Model] Adding Support for Qwen2VL as an Embedding Model. Using MrLight/dse-qwen2-2b-mrl-v1 (#9944)
Signed-off-by: FurtherAI <austin.veselka@lighton.ai>
Co-authored-by: FurtherAI <austin.veselka@lighton.ai>
2024-11-13 08:28:13 +00:00

210 lines
6.5 KiB
Python

from functools import partial
from typing import Callable, Dict, List, Type
import pytest
import torch
from PIL import Image
from transformers import BatchEncoding, Qwen2VLForConditionalGeneration
from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ....utils import large_gpu_test
from ..utils import check_embeddings_close
HF_TEXT_PROMPTS = [
# T -> X
(
"Query: Find me an everyday image that matches the given caption: The label of the object is stop sign", # noqa: E501,
Image.new("RGB", (56, 56))),
# T -> X
("Query: Retrieve an image of this caption: cherry blossom",
Image.new("RGB", (56, 56))),
]
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"What is shown in this image?",
"cherry_blossom":
"What is shown in this image?"
})
MODELS = ["MrLight/dse-qwen2-2b-mrl-v1"]
def get_messages(image: Image.Image, text: str, embed_text: bool):
# assert False, 'remember to use outer [] as required'
if embed_text:
messages = [{
"role":
"user",
"content": [
{
"type": "image",
"image": Image.new("RGB", (56, 56)),
"resized_height": 1,
"resized_width": 1
}, # need a dummy image here for an easier process.
{
"type": "text",
"text": text
},
]
}]
else:
messages = [{
"role":
"user",
"content": [{
"type": "image",
"image": image
}, {
"type": "text",
"text": text
}]
}]
return messages
def apply_chat_template_and_add_eos(
messages: List[Dict],
apply_chat_template_fn: Callable,
):
prompt = apply_chat_template_fn(
messages, tokenize=False, add_generation_prompt=True) + "<|endoftext|>"
return prompt
def postprocess_inputs(hf_model: HfRunner, inputs: BatchEncoding, **kwargs):
return hf_model.model.prepare_inputs_for_generation(**inputs, **kwargs)
def _run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
input_texts: List[str],
input_images: PromptImageInput,
embed_texts: List[bool],
model: str,
*,
dtype: str,
) -> None:
'''SET PYTHONPATH'''
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
with vllm_runner(model,
task="embedding",
dtype=dtype,
enforce_eager=True,
max_model_len=8192) as vllm_model:
tokenizer = vllm_model.model.get_tokenizer()
texts = [
# this is necessary because vllm_model.encode will not apply any
# templating to the prompt, and therefore lacks an image_pad
# token unless one is inserted beforehand (the (28,28) image
# above is converted to an image pad token by the chat template).
apply_chat_template_and_add_eos(
get_messages(image, text, False),
apply_chat_template_fn=tokenizer.apply_chat_template,
) for text, image in zip(input_texts, input_images)
# vllm will replace the pad token with the actual image,
# which may be a placeholder image, later.
]
vllm_outputs = vllm_model.encode(texts, images=input_images)
hf_outputs = []
with hf_runner(model,
dtype=dtype,
auto_cls=Qwen2VLForConditionalGeneration) as hf_model:
hf_model.postprocess_inputs = partial(
postprocess_inputs,
hf_model,
cache_position=torch.arange(
0,
1, # 1 for batch size
requires_grad=False),
use_cache=False)
for text, image, embed_text in zip(input_texts, input_images,
embed_texts):
# dse requires non-standard input processing
# because it needs an image_pad token
messages = get_messages(image, text, embed_text)
prompt = apply_chat_template_and_add_eos(
messages, hf_model.processor.apply_chat_template)
inputs = hf_model.get_inputs(
prompts=[[prompt]],
images=[[image]],
)
with torch.no_grad():
outputs = hf_model.model(
**hf_model.wrap_device(inputs[0],
device=hf_model.model.device.type),
return_dict=True,
output_hidden_states=True,
)
pooled_output = torch.nn.functional.normalize(
outputs.hidden_states[-1][0, -1], p=2, dim=-1)
hf_outputs.append(pooled_output.tolist())
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_models_text(
hf_runner,
vllm_runner,
image_assets,
model: str,
dtype: str,
) -> None:
input_texts_images = [(text, image_placeholder)
for text, image_placeholder in HF_TEXT_PROMPTS]
input_texts = [text for text, _ in input_texts_images]
input_images = [image for _, image in input_texts_images]
embed_texts = [True] * len(input_texts)
_run_test(
hf_runner,
vllm_runner,
input_texts,
input_images, # type: ignore
embed_texts,
model,
dtype=dtype,
)
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_models_image(
hf_runner,
vllm_runner,
image_assets,
model: str,
dtype: str,
) -> None:
input_texts_images = [
(text, asset.pil_image)
for text, asset in zip(HF_IMAGE_PROMPTS, image_assets)
]
input_texts = [text for text, _ in input_texts_images]
input_images = [image for _, image in input_texts_images]
embed_texts = [False] * len(input_texts)
_run_test(
hf_runner,
vllm_runner,
input_texts,
input_images,
embed_texts,
model,
dtype=dtype,
)