[Model] Support multiple images for qwen-vl (#8247)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>

parent e56bf27741
commit c6202daeed
@@ -254,7 +254,7 @@ Multimodal Language Models
     -
   * - :code:`QWenLMHeadModel`
     - Qwen-VL
-    - Image\ :sup:`E`
+    - Image\ :sup:`E+`
     - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
     -
   * - :code:`Qwen2VLForConditionalGeneration`
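Note: in the table above, :sup:`E` marks modalities that can also be fed as pre-computed embeddings, and the trailing `+` added by this commit marks support for multiple items (here, images) per text prompt. A minimal offline-inference sketch of the new multi-image path follows; the model name and placeholder format mirror the example script updated below, while the asset helper and sampling settings are illustrative assumptions, not part of this commit.

    # Illustrative sketch only: several images in one Qwen-VL prompt through
    # the offline LLM API. ImageAsset is used just to grab demo images; any
    # PIL images work.
    from vllm import LLM, SamplingParams
    from vllm.assets.image import ImageAsset

    llm = LLM(
        model="Qwen/Qwen-VL",
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 2},  # allow two images per prompt
    )

    images = [
        ImageAsset("stop_sign").pil_image,
        ImageAsset("cherry_blossom").pil_image,
    ]
    prompt = ("Picture 1: <img></img>\nPicture 2: <img></img>\n"
              "Describe the two images in detail.\n")

    outputs = llm.generate(
        {"prompt": prompt, "multi_modal_data": {"image": images}},
        SamplingParams(temperature=0.0, max_tokens=64),
    )
    print(outputs[0].outputs[0].text)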
@@ -19,7 +19,39 @@ IMAGE_URLS = [
 ]
 
 
-def load_phi3v(question, image_urls: List[str]):
+def load_qwenvl_chat(question: str, image_urls: List[str]):
+    model_name = "Qwen/Qwen-VL-Chat"
+    llm = LLM(
+        model=model_name,
+        trust_remote_code=True,
+        max_num_seqs=5,
+        limit_mm_per_prompt={"image": len(image_urls)},
+    )
+    placeholders = "".join(f"Picture {i}: <img></img>\n"
+                           for i, _ in enumerate(image_urls, start=1))
+
+    # This model does not have a chat_template attribute on its tokenizer,
+    # so we need to explicitly pass it. We use ChatML since it's used in the
+    # generation utils of the model:
+    # https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265
+    tokenizer = AutoTokenizer.from_pretrained(model_name,
+                                              trust_remote_code=True)
+
+    # Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating
+    chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"  # noqa: E501
+
+    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
+    prompt = tokenizer.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True,
+                                           chat_template=chat_template)
+
+    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
+    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
+    return llm, prompt, stop_token_ids, None, chat_template
+
+
+def load_phi3v(question: str, image_urls: List[str]):
     llm = LLM(
         model="microsoft/Phi-3.5-vision-instruct",
         trust_remote_code=True,
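For reference, the ChatML template passed above renders a multi-image request into the format expected by Qwen-VL-Chat's generation utils. A rough sketch of the rendered prompt (illustrative values, not part of the commit):

    # Roughly what apply_chat_template produces with the ChatML template above
    # and add_generation_prompt=True.
    placeholders = "Picture 1: <img></img>\nPicture 2: <img></img>\n"
    question = "What do these images have in common?"

    expected_prompt = (
        "<|im_start|>user\n"
        f"{placeholders}\n{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )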
@@ -30,10 +62,10 @@ def load_phi3v(question, image_urls: List[str]):
                            for i, _ in enumerate(image_urls, start=1))
     prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
     stop_token_ids = None
-    return llm, prompt, stop_token_ids, None
+    return llm, prompt, stop_token_ids, None, None
 
 
-def load_internvl(question, image_urls: List[str]):
+def load_internvl(question: str, image_urls: List[str]):
     model_name = "OpenGVLab/InternVL2-2B"
 
     llm = LLM(
@@ -61,7 +93,7 @@ def load_internvl(question, image_urls: List[str]):
     stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
     stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
 
-    return llm, prompt, stop_token_ids, None
+    return llm, prompt, stop_token_ids, None, None
 
 
 def load_qwen2_vl(question, image_urls: List[str]):
@@ -111,18 +143,19 @@ def load_qwen2_vl(question, image_urls: List[str]):
     else:
         image_data, _ = process_vision_info(messages)
 
-    return llm, prompt, stop_token_ids, image_data
+    return llm, prompt, stop_token_ids, image_data, None
 
 
 model_example_map = {
     "phi3_v": load_phi3v,
     "internvl_chat": load_internvl,
     "qwen2_vl": load_qwen2_vl,
+    "qwen_vl_chat": load_qwenvl_chat,
 }
 
 
 def run_generate(model, question: str, image_urls: List[str]):
-    llm, prompt, stop_token_ids, image_data = model_example_map[model](
+    llm, prompt, stop_token_ids, image_data, _ = model_example_map[model](
         question, image_urls)
     if image_data is None:
         image_data = [fetch_image(url) for url in image_urls]
@@ -146,29 +179,32 @@ def run_generate(model, question: str, image_urls: List[str]):
 
 
 def run_chat(model: str, question: str, image_urls: List[str]):
-    llm, _, stop_token_ids, _ = model_example_map[model](question, image_urls)
+    llm, _, stop_token_ids, _, chat_template = model_example_map[model](
+        question, image_urls)
 
     sampling_params = SamplingParams(temperature=0.0,
                                      max_tokens=128,
                                      stop_token_ids=stop_token_ids)
 
-    outputs = llm.chat([{
-        "role":
-        "user",
-        "content": [
-            {
-                "type": "text",
-                "text": question,
-            },
-            *({
-                "type": "image_url",
-                "image_url": {
-                    "url": image_url
-                },
-            } for image_url in image_urls),
-        ],
-    }],
-                       sampling_params=sampling_params)
+    outputs = llm.chat(
+        [{
+            "role":
+            "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": question,
+                },
+                *({
+                    "type": "image_url",
+                    "image_url": {
+                        "url": image_url
+                    },
+                } for image_url in image_urls),
+            ],
+        }],
+        sampling_params=sampling_params,
+        chat_template=chat_template,
+    )
 
     for o in outputs:
         generated_text = o.outputs[0].text
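The hunks above extend the multi-image offline inference example; the hunks below update the Qwen model tests. A hedged usage sketch of the new loader through the script's existing driver functions (the question text is illustrative):

    # Hypothetical driver snippet, assuming it runs inside the example module:
    # exercise the new Qwen-VL-Chat entry in model_example_map via both paths.
    question = "What do these images have in common?"
    run_generate("qwen_vl_chat", question, IMAGE_URLS)
    run_chat("qwen_vl_chat", question, IMAGE_URLS)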
@@ -1,11 +1,17 @@
 import pathlib
-from typing import List, Optional, Type
+from typing import Dict, List, Optional, Tuple, Type, Union
 
 import pytest
+import torch
+from PIL.Image import Image
 
-from vllm.multimodal.utils import rescale_image_size
+from vllm.config import ModelConfig
+from vllm.inputs import InputContext, LLMInputs
+from vllm.multimodal.base import MultiModalInputs
+from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size
 
-from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from ..conftest import (IMAGE_ASSETS, HfRunner, ImageAsset, PromptImageInput,
+                        VllmRunner, _ImageAssets)
 from .utils import check_logprobs_close
 
 pytestmark = pytest.mark.vlm
@@ -23,19 +29,205 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "Picture 1: <img></img>\nWhat is the season?: ",
 })
 
-HF_MULTIIMAGE_IMAGE_PROMPT = "Picture 1: <img></img>\nPicture 2: <img></img>\nCan you compare these images?\n"  # noqa: E501
+HF_MULTIIMAGE_IMAGE_PROMPT = "Picture 1: <img></img>\nPicture 2: <img></img>\nDescribe the two images in detail.\n"  # noqa: E501
+
+### Multimodal preprocessing tests
+SAMPLE_IMAGE = IMAGE_ASSETS[0].pil_image
+# These values are specific to Qwen-VL/Chat; we can get these from the model
+# config also, but they are hardcoded here to keep the parameterize/fixtures
+# easy to read.
+IMG_START_ID = 151857
+IMG_END_ID = 151858
+IMG_PAD_ID = 151859
+TOKS_PER_IMG = 256
+VIS_ENC_DIM = 4096
+IMG_SIZE = 448
+
+
+def build_model_context(model_name: str,
+                        tokenizer_name: Optional[str] = None,
+                        trust_remote_code: bool = False):
+    """Creates an InputContext for a given model.
+
+    Args:
+        model_name: Name of the model being considered.
+        tokenizer_name: Name of the tokenizer being considered.
+        trust_remote_code: Whether or not to allow loading remote code.
+
+    Returns:
+        InputContext for the model being considered.
+    """
+    if tokenizer_name is None:
+        tokenizer_name = model_name
+    model_config = ModelConfig(
+        model_name,
+        tokenizer_name,
+        tokenizer_mode="auto",
+        trust_remote_code=trust_remote_code,
+        dtype="float32",
+        seed=0,
+    )
+    return InputContext(model_config)
+
+
+@pytest.fixture()
+def input_mapper_for_qwen():
+    # Lazy import to avoid initializing CUDA during test collection
+    from vllm.model_executor.models.qwen import input_mapper_for_qwen
+    return input_mapper_for_qwen
+
+
+@pytest.fixture()
+def input_processor_for_qwen():
+    # Lazy import to avoid initializing CUDA during test collection
+    from vllm.model_executor.models.qwen import input_processor_for_qwen
+    return input_processor_for_qwen
+
+
+@pytest.fixture()
+def qwen_vl_context() -> InputContext:
+    """Get an InputContext for Qwen-VL."""
+    return build_model_context(model_name="Qwen/Qwen-VL",
+                               trust_remote_code=True)
+
+
+# Happy path tests for single/multi-image scenarios for the multimodal
+# input processor and mapper, respectively
+@pytest.mark.parametrize("num_images", [1, 2])
+def test_input_processor_valid_mm_data(input_processor_for_qwen,
+                                       qwen_vl_context: InputContext,
+                                       num_images: int):
+    """Happy cases for image inputs to Qwen's multimodal input processor."""
+    prompt = "".join(
+        [f"Picture {num}: <img></img>\n" for num in range(1, num_images + 1)])
+    inputs = LLMInputs(
+        prompt=prompt,
+        # When processing multimodal data for a multimodal model, the qwen
+        # input processor will overwrite the provided prompt_token_ids with
+        # the image prompts
+        prompt_token_ids=None,
+        multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)},
+    )
+    proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs)
+    assert isinstance(proc_inputs, dict)
+
+    # Each image should have one start / stop and a fixed context of 256
+    proc_tokens = proc_inputs["prompt_token_ids"]
+    assert proc_tokens.count(IMG_START_ID) == num_images
+    assert proc_tokens.count(IMG_END_ID) == num_images
+    assert proc_tokens.count(IMG_PAD_ID) == num_images * TOKS_PER_IMG
+
+
+@pytest.mark.parametrize(
+    "img_data,expected_shape",
+    [
+        # single / multi-image
+        (SAMPLE_IMAGE, (1, 3, IMG_SIZE, IMG_SIZE)),
+        (2 * [SAMPLE_IMAGE], (2, 3, IMG_SIZE, IMG_SIZE)),
+        # single / multi-image embeddings
+        (torch.rand(
+            (TOKS_PER_IMG, VIS_ENC_DIM)), (1, TOKS_PER_IMG, VIS_ENC_DIM)),
+        (torch.rand(
+            (1, TOKS_PER_IMG, VIS_ENC_DIM)), (1, TOKS_PER_IMG, VIS_ENC_DIM)),
+        (torch.rand(
+            (2, TOKS_PER_IMG, VIS_ENC_DIM)), (2, TOKS_PER_IMG, VIS_ENC_DIM)),
+    ])
+def test_input_mapper_valid_mm_data(input_mapper_for_qwen,
+                                    qwen_vl_context: InputContext,
+                                    img_data: Union[torch.Tensor, List[Image],
+                                                    Image],
+                                    expected_shape: List[int]):
+    """Happy cases for image inputs to Qwen's multimodal input mapper."""
+    mapped_img_data = input_mapper_for_qwen(qwen_vl_context, img_data)
+    # Ensure that we get the appropriately shaped pixel_values
+    # for images and image embeddings, respectively.
+    assert isinstance(mapped_img_data, MultiModalInputs)
+    assert "pixel_values" in mapped_img_data
+    assert mapped_img_data["pixel_values"].shape == expected_shape
+
+
+# Sad path tests for the multimodal input processor and mapper, respectively
+@pytest.mark.parametrize("mm_data", [
+    {
+        "image": torch.rand((5))
+    },
+    {
+        "image": torch.rand((5, 5, 5, 5, 5))
+    },
+])
+def test_input_processor_invalid_mm_data(input_processor_for_qwen,
+                                         qwen_vl_context: InputContext,
+                                         mm_data: Dict[str, torch.Tensor]):
+    """Test sad cases validated in Qwen's multimodal input processor."""
+    tokenizer = cached_get_tokenizer(qwen_vl_context.model_config.tokenizer,
+                                     trust_remote_code=True)
+    prompt = "Picture 1: <img></img>\n"
+    prompt_token_ids = tokenizer.encode(prompt)
+    inputs = LLMInputs(prompt=prompt,
+                       prompt_token_ids=prompt_token_ids,
+                       multi_modal_data=mm_data)
+    # Should fail since we have too many or too few dimensions for embeddings
+    with pytest.raises(ValueError):
+        input_processor_for_qwen(qwen_vl_context, inputs)
+
+
+@pytest.mark.parametrize(
+    "img_data",
+    [
+        # Wrong context length
+        torch.rand((1, TOKS_PER_IMG + 10, VIS_ENC_DIM)),
+        # Wrong visual encoder output size
+        torch.rand((1, TOKS_PER_IMG, VIS_ENC_DIM + 10)),
+    ])
+def test_input_mapper_invalid_mm_data(
+    input_mapper_for_qwen,
+    qwen_vl_context: InputContext,
+    img_data: Union[torch.Tensor, List[Image], Image],
+):
+    """Sad cases validated in Qwen VL's multimodal input mapper."""
+    with pytest.raises(ValueError):
+        input_mapper_for_qwen(qwen_vl_context, img_data)
+
+
+### End-to-end generation tests
+def get_prompt_with_path(tmp_path: pathlib.PosixPath, prompt: str,
+                         assets: Union[_ImageAssets, List[ImageAsset]]) -> str:
+    """Given a temporary dir path, export one or more image assets into the
+    tempdir & replace the image placeholders in the prompt string with the
+    local paths so that the HF version of Qwen-VL can resolve the path and
+    load the image in its forward() call.
+
+    Args:
+        tmp_path: Tempdir for test under consideration.
+        prompt: Prompt with image placeholders.
+        assets: List of image assets whose len equals the num placeholders.
+    """
+    # Ensure that the number of placeholders matches the number of assets;
+    # if this is not true, the test is probably written incorrectly.
+    assert prompt.count("<img></img>") == len(assets)
+
+    # Replace the placeholders with local paths to the exported assets
+    for asset in assets:
+        image_tmp_path = tmp_path / f"{asset.name}.jpg"
+        asset.pil_image.save(image_tmp_path)
+        prompt = prompt.replace(
+            "<img></img>",
+            f"<img>{image_tmp_path}</img>",
+            1,
+        )
+    return prompt
+
+
 ### Tests for multimodal Qwen models
 def run_test(
-    tmp_path: pathlib.PosixPath,
     hf_runner: Type[HfRunner],
     vllm_runner: Type[VllmRunner],
-    image_assets: _ImageAssets,
+    inputs: List[Tuple[List[str], PromptImageInput]],
     model: str,
     *,
-    size_factors: List[float],
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
+    mm_limit: int,
     tensor_parallel_size: int,
     distributed_executor_backend: Optional[str] = None,
 ):
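The hardcoded constants above make the token accounting easy to check: each image placeholder expands to one IMG_START token, TOKS_PER_IMG (256) pad tokens, and one IMG_END token, which is exactly what the happy-path processor test asserts. An illustrative back-of-the-envelope check of why a later hunk raises max_model_len from 300 to 1024:

    # Illustrative arithmetic only; constants copied from the test file above.
    TOKS_PER_IMG = 256

    def image_token_budget(num_images: int) -> int:
        # one IMG_START + TOKS_PER_IMG pads + one IMG_END per image
        return num_images * (TOKS_PER_IMG + 2)

    assert image_token_budget(1) == 258  # fits under the old max_model_len=300
    assert image_token_budget(2) == 516  # does not, hence the bump to 1024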
@@ -48,23 +240,6 @@ def run_test(
     Note, the text input is also adjusted to abide by vllm contract.
     The text output is sanitized to be able to compare with hf.
     """
-    images = [asset.pil_image for asset in image_assets]
-
-    # Export the images to a tempdir and substitute it into the hf prompt;
-    # the contents between <img>/</img> will be ignored by VLLM, but the
-    # transformers implementation for the visual transformer parses this to
-    # reload it in the forward call; the contents are treated as a URL or a
-    # local path.
-    for idx, asset in enumerate(image_assets):
-        image_tmp_path = tmp_path / f"{asset.name}.jpg"
-        asset.pil_image.save(image_tmp_path)
-        HF_IMAGE_PROMPTS[idx] = HF_IMAGE_PROMPTS[idx].replace(
-            "<img></img>", f"<img>{image_tmp_path}</img>")
-
-    inputs_per_image = [(
-        [prompt for _ in size_factors],
-        [rescale_image_size(image, factor) for factor in size_factors],
-    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
 
     # NOTE: take care of the order. run vLLM first, and then run HF.
     # vLLM needs a fresh new process without cuda initialization.
@@ -72,11 +247,12 @@ def run_test(
     # will hurt multiprocessing backend with fork method (the default method).
 
     # max_model_len should be greater than image_feature_size
-    # Qwen encodes images into a fixed content size of 256
+    # Qwen encodes each image into a fixed content size of 256
     with vllm_runner(model,
-                     max_model_len=300,
+                     max_model_len=1024,
                      max_num_seqs=1,
                      dtype=dtype,
+                     limit_mm_per_prompt={"image": mm_limit},
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
                      enforce_eager=True) as vllm_model:
@@ -85,7 +261,7 @@ def run_test(
                                                 max_tokens,
                                                 num_logprobs=num_logprobs,
                                                 images=images)
-            for prompts, images in inputs_per_image
+            for prompts, images in inputs
         ]
 
     with hf_runner(model, dtype=dtype) as hf_model:
@@ -94,7 +270,7 @@ def run_test(
                                                  max_tokens,
                                                  num_logprobs=num_logprobs,
                                                  images=images)
-            for prompts, images in inputs_per_image
+            for prompts, images in inputs
         ]
 
     for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
@@ -125,19 +301,81 @@ def run_test(
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [8])
 @pytest.mark.parametrize("num_logprobs", [5])
-def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets,
-                           model, size_factors, dtype, max_tokens,
-                           num_logprobs) -> None:
+def test_multimodal_models_single_image(tmp_path: pathlib.PosixPath,
+                                        hf_runner: Type[HfRunner],
+                                        vllm_runner: Type[VllmRunner],
+                                        image_assets: _ImageAssets, model: str,
+                                        size_factors: List[float], dtype: str,
+                                        max_tokens: int,
+                                        num_logprobs: int) -> None:
+    """Tests multimodal models with single image prompts."""
+    images = [asset.pil_image for asset in image_assets]
+
+    prompts = [
+        get_prompt_with_path(tmp_path, prompt, [asset])
+        for prompt, asset in zip(HF_IMAGE_PROMPTS, image_assets)
+    ]
+
+    inputs = [(
+        [prompt for _ in size_factors],
+        [rescale_image_size(image, factor) for factor in size_factors],
+    ) for image, prompt in zip(images, prompts)]
+
     run_test(
-        tmp_path,
         hf_runner,
         vllm_runner,
-        image_assets,
+        inputs,
         model,
-        size_factors=size_factors,
         dtype=dtype,
         max_tokens=max_tokens,
         num_logprobs=num_logprobs,
+        mm_limit=1,
         tensor_parallel_size=1,
     )
 
+
+@pytest.mark.parametrize("model", multimodal_models)
+@pytest.mark.parametrize(
+    "size_factors",
+    [
+        # No image
+        [],
+        # Single-scale
+        [1.0],
+        # Single-scale, batched
+        [1.0, 1.0, 1.0],
+        # Multi-scale
+        [0.25, 0.5, 1.0],
+    ],
+)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_multimodal_models_multi_image(tmp_path: pathlib.PosixPath,
+                                       hf_runner: Type[HfRunner],
+                                       vllm_runner: Type[VllmRunner],
+                                       image_assets: _ImageAssets, model: str,
+                                       size_factors: List[float], dtype: str,
+                                       max_tokens: int,
+                                       num_logprobs: int) -> None:
+    """Tests multimodal models with multi-image prompts."""
+    images = [asset.pil_image for asset in image_assets]
+    # Put all of the images into one prompt.
+    prompt = get_prompt_with_path(tmp_path, HF_MULTIIMAGE_IMAGE_PROMPT,
+                                  image_assets)
+    inputs = [([prompt for _ in size_factors],
+               [[rescale_image_size(image, factor) for image in images]
+                for factor in size_factors])]
+
+    run_test(
+        hf_runner,
+        vllm_runner,
+        inputs,
+        model,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        mm_limit=2,
+        tensor_parallel_size=1,
+    )
+
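The new single- and multi-image tests can be selected with pytest's keyword filter; the file path below is assumed from the relative imports (tests/models/test_qwen.py) rather than stated in the diff:

    pytest tests/models/test_qwen.py -k "single_image or multi_image"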
@@ -150,7 +388,7 @@ def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets,
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_text_only_qwen_model_can_be_loaded_and_run(
     vllm_runner: Type[VllmRunner],
-    example_prompts,
+    example_prompts: List[str],
     model: str,
     *,
     dtype: str,
@@ -47,6 +47,7 @@ from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)
+from vllm.utils import is_list_of
 
 from .utils import flatten_bn, is_pp_missing_parameter, make_layers
 
@@ -684,9 +685,12 @@ def input_processor_for_qwen(ctx: InputContext,
             raise ValueError(
                 f"Expected img embeds to be have 3 dimensions, got {num_dims}")
         num_images = 1 if num_dims == 2 else image_data.shape[0]
-    else:
-        # TODO - handle multiple image inputs once the API is solidified
+    elif isinstance(image_data, Image.Image):
         num_images = 1
+    elif is_list_of(image_data, Image.Image):
+        num_images = len(image_data)
+    else:
+        raise TypeError(f"Invalid image type: {type(image_data)}")
 
     if prompt is None:
         prompt = tokenizer.decode(prompt_token_ids)
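Summarizing the branch added above outside the diff: the processor now distinguishes embedding tensors, a single PIL image, and a list of PIL images when counting images. A standalone sketch under the same imports as the hunk at the top of this file's diff (illustrative, not the actual vLLM code path):

    import torch
    from PIL import Image
    from vllm.utils import is_list_of

    def count_images(image_data) -> int:
        if isinstance(image_data, torch.Tensor):
            # 2D tensor = one image's embeddings; 3D = a batch of embeddings
            return 1 if image_data.ndim == 2 else image_data.shape[0]
        if isinstance(image_data, Image.Image):
            return 1
        if is_list_of(image_data, Image.Image):
            return len(image_data)
        raise TypeError(f"Invalid image type: {type(image_data)}")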
@@ -767,11 +771,11 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
                 f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but "
                 f"received shape [{data.shape}]")
         pixel_values = data
 
     else:
         transform = build_normalization_transform(image_size)
-        # TODO - handle multiple image inputs once the API is solidified
-        transformed_images = [transform(data)]
+        if not isinstance(data, (list, tuple)):
+            data = [data]
+        transformed_images = [transform(datum) for datum in data]
         pixel_values = torch.stack(transformed_images, dim=0)
     return MultiModalInputs({"pixel_values": pixel_values})
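The mapper change above replaces the single-image TODO with list normalization. A minimal restatement, assuming `transform` maps one image to a (3, H, W) tensor (sketch only, not the actual vLLM code path):

    import torch

    def to_pixel_values(data, transform):
        if not isinstance(data, (list, tuple)):
            data = [data]  # a single image becomes a list of one
        transformed = [transform(datum) for datum in data]  # each -> (3, H, W)
        return torch.stack(transformed, dim=0)  # -> (num_images, 3, H, W)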