[Model] Refactor Qwen2-VL to use merged multimodal processor (#11258)
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 7379b3d4b2
commit e24113a8fe
@@ -447,7 +447,6 @@ def run_qwen_vl(question: str, modality: str):

# Qwen2-VL
def run_qwen2_vl(question: str, modality: str):
    assert modality == "image"

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

@@ -463,8 +462,13 @@ def run_qwen2_vl(question: str, modality: str):
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    if modality == "image":
        placeholder = "<|image_pad|>"
    elif modality == "video":
        placeholder = "<|video_pad|>"

    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
              "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
              f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
              f"{question}<|im_end|>\n"
              "<|im_start|>assistant\n")
    stop_token_ids = None
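Note (not part of the diff): a minimal sketch of the prompt the updated example now builds for each modality. The helper name build_qwen2_vl_prompt is hypothetical; the placeholder strings are the ones used above.

# Illustration only: prompt construction mirroring the updated run_qwen2_vl.
def build_qwen2_vl_prompt(question: str, modality: str) -> str:
    placeholder = "<|image_pad|>" if modality == "image" else "<|video_pad|>"
    return ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
            f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
            f"{question}<|im_end|>\n"
            "<|im_start|>assistant\n")

# build_qwen2_vl_prompt("What happens in the clip?", "video") now embeds
# <|video_pad|> instead of hard-coding <|image_pad|>.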
@@ -1,12 +1,9 @@
from typing import Any, Dict, Tuple

import pytest
import torch
from PIL.Image import Image
from transformers import AutoTokenizer

from vllm.inputs import InputContext, token_inputs
from vllm.multimodal import MultiModalRegistry
from vllm.inputs import InputContext, InputProcessingContext

from .....conftest import _ImageAssets
from ....utils import build_model_context
@@ -20,22 +17,9 @@ MAX_PIXELS = "max_pixels"
# NOTE: Qwen2VL supports multiple input modalities, so it registers multiple
# input mappers.
@pytest.fixture()
def image_input_mapper_for_qwen2_vl():
    from vllm.model_executor.models.qwen2_vl import (
        image_input_mapper_for_qwen2_vl)
    return image_input_mapper_for_qwen2_vl


@pytest.fixture()
def input_processor_for_qwen2_vl():
    from vllm.model_executor.models.qwen2_vl import (
        input_processor_for_qwen2_vl)
    return input_processor_for_qwen2_vl


@pytest.fixture()
def qwen2_vl_context() -> InputContext:
    return build_model_context(model_name=MODEL)
def processor_for_qwen2_vl():
    from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor
    return Qwen2VLMultiModalProcessor


@pytest.fixture()
@@ -45,12 +29,6 @@ def get_max_qwen2_vl_image_tokens():
    return get_max_qwen2_vl_image_tokens


@pytest.fixture()
def dummy_data_for_qwen2_vl():
    from vllm.model_executor.models.qwen2_vl import dummy_data_for_qwen2_vl
    return dummy_data_for_qwen2_vl


@pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [
    ({}, 1225),
    ({
@@ -58,110 +36,70 @@ def dummy_data_for_qwen2_vl():
        MAX_PIXELS: 512**2
    }, 324),
])
def test_qwen2_vl_max_image_tokens(get_max_qwen2_vl_image_tokens,
                                   qwen2_vl_context: InputContext,
                                   mm_processor_kwargs: Dict[str, Any],
                                   expected_max_tokens: int):
@pytest.mark.parametrize("model", [MODEL])
def test_qwen2_vl_max_image_tokens(
    get_max_qwen2_vl_image_tokens,
    model: str,
    mm_processor_kwargs: Dict[str, Any],
    expected_max_tokens: int,
):
    """Ensure that the max token calc handles min/max pixels properly."""
    actual_max_tokens = get_max_qwen2_vl_image_tokens(qwen2_vl_context,
                                                      **mm_processor_kwargs)
    ctx = build_model_context(
        model_name=model,
        tokenizer_name=model,
        mm_processor_kwargs=None,
    )

    actual_max_tokens = get_max_qwen2_vl_image_tokens(
        InputContext(ctx.model_config), **mm_processor_kwargs)
    assert actual_max_tokens == expected_max_tokens


@pytest.mark.parametrize("mm_processor_kwargs,token_count,img_size", [
    [{}, 1225, (980, 980)],
    [{
        MIN_PIXELS: 64**2,
        MAX_PIXELS: 512**2
    }, 324, (504, 504)],
])
def test_qwen2_vl_dummy_data(dummy_data_for_qwen2_vl,
                             qwen2_vl_context: InputContext,
                             mm_processor_kwargs: Dict[str, Any],
                             token_count: int, img_size: Tuple[int, int]):
    """Ensure that the dummy data handles min/max pixels properly."""
    seq_len = 3000
    hf_config = qwen2_vl_context.get_hf_config()
    image_token_id = hf_config.image_token_id

    # NOTE: video value is required, but isn't actually used
    # when making the dummy data except for error handling currently
    dummy_data = dummy_data_for_qwen2_vl(
        ctx=qwen2_vl_context,
        seq_len=seq_len,
        mm_counts={
            "image": 1,
            "video": 0
        },
        **mm_processor_kwargs,
@pytest.mark.parametrize(
    "mm_processor_kwargs, expected_toks_per_img, expected_pixels_shape", [
        ({}, 1426, (5704, 1176)),
        ({
            MIN_PIXELS: 64**2,
            MAX_PIXELS: 512**2
        }, 330, (1320, 1176)),
    ])
@pytest.mark.parametrize("model", [MODEL])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
    processor_for_qwen2_vl,
    image_assets: _ImageAssets,
    model: str,
    mm_processor_kwargs: Dict[str, Any],
    expected_toks_per_img: int,
    expected_pixels_shape: Tuple[int, int],
    num_imgs: int,
):
    """Ensure Qwen2VLMultiModalProcessor handles min/max pixels properly."""
    # Same as the previous test - don't initialize mm_processor_kwargs
    # in this test and assume that the kwargs will be correctly expanded by
    # the partial when calling the custom input processor.
    ctx = build_model_context(
        model_name=model,
        tokenizer_name=model,
        mm_processor_kwargs=None,
    )
    seq_data = dummy_data.seq_data
    mm_data = dummy_data.multi_modal_data
    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
    ctx = InputProcessingContext(ctx.model_config, tokenizer)
    # Build the image str / prompt based on the number of images we pass
    prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
    images = [image_assets[0].pil_image] * num_imgs

    # Ensure we have the right number of placeholders for min/max pixel values
    assert seq_data.get_token_ids().count(image_token_id) == token_count
    mm_data = {"image": images}

    # Ensure the images were resized correctly
    image = mm_data["image"]
    assert isinstance(image, Image)
    assert image.size == img_size
    processor = processor_for_qwen2_vl(ctx)
    processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)

    # Ensure we have the right number of placeholders per num_crops size
    hf_processor = processor._get_hf_processor(**mm_processor_kwargs)
    image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token)
    img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
    pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape

@pytest.mark.parametrize("mm_processor_kwargs,num_placeholders", [
|
||||
({}, 1426),
|
||||
({
|
||||
MIN_PIXELS: 64**2,
|
||||
MAX_PIXELS: 512**2
|
||||
}, 330),
|
||||
])
|
||||
def test_input_processor(input_processor_for_qwen2_vl,
|
||||
qwen2_vl_context: InputContext,
|
||||
image_assets: _ImageAssets, num_placeholders: int,
|
||||
mm_processor_kwargs: Dict[str, Any]):
|
||||
"""Ensure that the image processor handles min/max pixels properly."""
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL)
|
||||
prompt = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
|
||||
image = image_assets[0].pil_image
|
||||
hf_config = qwen2_vl_context.get_hf_config()
|
||||
image_token_id = hf_config.image_token_id
|
||||
|
||||
inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
|
||||
prompt=prompt,
|
||||
multi_modal_data={"image": [image]})
|
||||
|
||||
processed_inputs = input_processor_for_qwen2_vl(qwen2_vl_context, inputs,
|
||||
**mm_processor_kwargs)
|
||||
assert processed_inputs["prompt_token_ids"].count(
|
||||
image_token_id) == num_placeholders
|
||||
assert len(processed_inputs["multi_modal_data"]["image"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mm_processor_kwargs,pixels_shape", [
|
||||
({}, [5704, 1176]),
|
||||
({
|
||||
MIN_PIXELS: 64**2,
|
||||
MAX_PIXELS: 512**2
|
||||
}, [1320, 1176]),
|
||||
])
|
||||
def test_image_mapper_override(qwen2_vl_context: InputContext,
|
||||
image_assets: _ImageAssets,
|
||||
mm_processor_kwargs: Dict[str, Any],
|
||||
pixels_shape: Tuple[int, int]):
|
||||
"""Ensure that the image mapper handles min/max pixels properly."""
|
||||
mm_registry = MultiModalRegistry()
|
||||
mm_registry.init_mm_limits_per_prompt(qwen2_vl_context.model_config)
|
||||
|
||||
image = image_assets[0].pil_image
|
||||
|
||||
mapped_output = mm_registry.map_input(
|
||||
qwen2_vl_context.model_config,
|
||||
{"image": image},
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
||||
# Dimension 0 of pixel values should match the product of image_grid_thw
|
||||
actual_pixels_shape = mapped_output["pixel_values"].shape
|
||||
assert list(actual_pixels_shape) == pixels_shape
|
||||
assert actual_pixels_shape[0] == torch.prod(
|
||||
mapped_output["image_grid_thw"])
|
||||
assert img_tok_count == expected_toks_per_img * num_imgs
|
||||
assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs
|
||||
assert pixel_shape[1] == expected_pixels_shape[1]
|
||||
|
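Note (not part of the diff): a rough sanity check of how the parametrized expectations above relate to each other, assuming Qwen2-VL's default vision settings (patch_size=14, spatial merge_size=2, temporal_patch_size=2, 3 channels). expected_pixels_shape is (num_patches, flattened_patch_dim).

# Illustration only: relating expected_toks_per_img to expected_pixels_shape.
patch_size, merge_size, temporal_patch_size, channels = 14, 2, 2, 3

expected_toks_per_img = 1426                                  # default min/max pixels case
num_patches = expected_toks_per_img * merge_size**2           # 5704 rows in pixel_values
patch_dim = channels * temporal_patch_size * patch_size**2    # 1176 columns per patch

assert (num_patches, patch_dim) == (5704, 1176)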
@@ -164,7 +164,9 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        audio_len = get_max_qwen2_audio_audio_tokens(self.ctx)
        feature_extractor = self._get_feature_extractor()
        sampling_rate = feature_extractor.sampling_rate
        audio_len = feature_extractor.chunk_length * sampling_rate

        audio_count = mm_counts["audio"]
        audio = np.zeros(audio_len)
@@ -22,28 +22,26 @@
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import cached_property, partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
                    Optional, Set, Tuple, Type, TypedDict, Union)
from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
                    Tuple, Type, TypedDict, Union)

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from PIL import Image
from transformers.image_utils import (get_image_size,
                                      infer_channel_dimension_format,
                                      to_numpy_array)
from transformers import BatchFeature
from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
                                          Qwen2VLProcessor)
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
    Qwen2VLConfig, Qwen2VLVisionConfig)
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
    make_batched_images, make_batched_videos, smart_resize)
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext, token_inputs)
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU
@@ -56,14 +54,14 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
                                    MultiModalKwargs, NestedTensors)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.multimodal.inputs import MultiModalDataDict, NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataItems, ProcessorInputs,
                                        PromptReplacement)
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of

from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
@@ -159,7 +157,7 @@ class Qwen2VisionMLP(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: int = None,
        hidden_features: int,
        act_layer: Type[nn.Module] = QuickGELU,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
@@ -644,78 +642,8 @@ class Qwen2VisionTransformer(nn.Module):
# === Vision input helpers === #


def get_mm_processor_kwargs(
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None) -> Dict[str, int]:
    mm_processor_kwargs = {}
    if min_pixels:
        mm_processor_kwargs["min_pixels"] = min_pixels
    if max_pixels:
        mm_processor_kwargs["max_pixels"] = max_pixels
    return mm_processor_kwargs


def mm_input_mapper_for_qwen2_vl(
    ctx: InputContext,
    data: MultiModalData[object],
    data_type_key: str,
    *,
    min_pixels: Optional[int] = None,
    max_pixels: Optional[int] = None,
) -> MultiModalKwargs:
    """Input mapper for Qwen2-VL."""
    if data_type_key == "image" and isinstance(data, dict):
        return MultiModalKwargs({
            "image_embeds": data.get("image_embeds"),
            "image_grid_thw": data.get("image_grid_thw"),
        })
    if data_type_key == "video" and isinstance(data, dict):
        return MultiModalKwargs({
            "video_embeds": data.get("video_embeds"),
            "video_grid_thw": data.get("video_grid_thw"),
        })

    model_config = ctx.model_config
    # Handle mm processor kwargs; we pass these at creation time
    # because preprocess() in transformers doesn't expose them
    mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels,
                                                  max_pixels=max_pixels)
    image_processor = cached_get_image_processor(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        **mm_processor_kwargs,
    )
    if image_processor is None:
        raise RuntimeError("No HuggingFace processor is available "
                           "to process the image object")

    images = None
    videos = None
    if data_type_key == "image":
        images = data
    else:
        assert data_type_key == "video"
        videos = data

    try:
        batch_data = image_processor \
            .preprocess(images=images, videos=videos, return_tensors="pt") \
            .data
    except Exception:
        logger.error("Failed to process image (%s)", data)
        raise

    return MultiModalKwargs(batch_data)


image_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl,
                                          data_type_key="image")
video_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl,
                                          data_type_key="video")


def _get_vision_info(
    image_processor,
    vision_config: Qwen2VLVisionConfig,
    height: int,
    width: int,
    min_pixels: int,
@@ -726,12 +654,15 @@ def _get_vision_info(
):
    """Get information (resized height / width and number of vision tokens)
    of input image / video frame."""
    patch_size = vision_config.patch_size
    merge_size = vision_config.spatial_merge_size
    temporal_patch_size = vision_config.temporal_patch_size

    if do_resize:
        resized_height, resized_width = smart_resize(
            height=height,
            width=width,
            factor=image_processor.patch_size * image_processor.merge_size,
            factor=patch_size * merge_size,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
@@ -742,54 +673,41 @@
        grid_t = mm_count
    else:
        assert data_type_key == "video"
        grid_t = max(mm_count // image_processor.temporal_patch_size, 1)
        grid_t = max(mm_count // temporal_patch_size, 1)

    grid_h = resized_height // image_processor.patch_size
    grid_w = resized_width // image_processor.patch_size
    grid_h = resized_height // patch_size
    grid_w = resized_width // patch_size
    vision_tokens = grid_t * grid_h * grid_w
    llm_num_vision_tokens = (vision_tokens // image_processor.merge_size //
                             image_processor.merge_size)
    llm_num_vision_tokens = vision_tokens // (merge_size**2)

    return resized_height, resized_width, llm_num_vision_tokens


def _get_max_image_info(
    image_processor,
    data_type_key: str = "image",
    mm_count: int = 1,
    min_pixels: Optional[int] = None,
    max_pixels: Optional[int] = None,
):
    # Limit min / max pixels unless they're explicitly provided
    if min_pixels is None:
        min_pixels = max(image_processor.min_pixels, 28 * 28)
    if max_pixels is None:
        max_pixels = min(image_processor.max_pixels, 1280 * 28 * 28)

    return _get_vision_info(
        image_processor,
        height=9999999,
        width=9999999,
        min_pixels=min_pixels,
        max_pixels=max_pixels,
        data_type_key=data_type_key,
        mm_count=mm_count,
    )
def _get_image_processor(hf_processor: Qwen2VLProcessor):
    image_processor = hf_processor.image_processor  # type: ignore
    assert isinstance(image_processor, Qwen2VLImageProcessor)
    return image_processor


def get_max_qwen2_vl_mm_tokens(ctx: InputContext,
                               data_type_key: str,
                               *,
                               min_pixels=None,
                               max_pixels=None) -> int:
    mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels,
                                                  max_pixels=max_pixels)
    image_processor = cached_get_image_processor(ctx.model_config.model,
                                                 **mm_processor_kwargs)
    max_resized_height, max_resized_width, max_llm_image_tokens = \
        _get_max_image_info(image_processor, data_type_key=data_type_key,
                            mm_count=1, min_pixels=min_pixels,
                            max_pixels=max_pixels)
                               min_pixels: Optional[int] = None,
                               max_pixels: Optional[int] = None) -> int:
    hf_config = ctx.get_hf_config(Qwen2VLConfig)
    vision_config = hf_config.vision_config

    hf_processor = ctx.get_hf_processor(Qwen2VLProcessor)
    image_processor = _get_image_processor(hf_processor)

    _, _, max_llm_image_tokens = _get_vision_info(
        vision_config,
        height=9999999,
        width=9999999,
        min_pixels=min_pixels or image_processor.min_pixels,
        max_pixels=max_pixels or image_processor.max_pixels,
        data_type_key=data_type_key,
    )
    return max_llm_image_tokens

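Note (not part of the diff): the token math in the reworked _get_vision_info boils down to the arithmetic below, shown with Qwen2-VL's default patch_size=14 and spatial_merge_size=2 and a 980x980 input; the result matches the default-case value 1225 used in the old tests above.

# Illustration only: how grid size maps to LLM vision tokens.
patch_size, merge_size = 14, 2
resized_height, resized_width = 980, 980    # already a multiple of patch_size * merge_size

grid_t = 1                                  # a single image
grid_h = resized_height // patch_size       # 70
grid_w = resized_width // patch_size        # 70
vision_tokens = grid_t * grid_h * grid_w    # 4900
llm_num_vision_tokens = vision_tokens // (merge_size**2)

assert llm_num_vision_tokens == 1225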
@@ -799,290 +717,166 @@ get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens,
                                        data_type_key="video")


def dummy_data_for_qwen2_vl(
    ctx: InputContext,
    seq_len: int,
    mm_counts: Mapping[str, int],
    *,
    min_pixels: Optional[int] = None,
    max_pixels: Optional[int] = None
) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
    mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels,
                                                  max_pixels=max_pixels)
    image_processor = cached_get_image_processor(ctx.model_config.model,
                                                 **mm_processor_kwargs)
class Qwen2VLMultiModalDataItems(MultiModalDataItems):

    num_images = mm_counts["image"]
    max_resized_height, max_resized_width, max_llm_image_tokens = \
        _get_max_image_info(image_processor, data_type_key="image",
                            mm_count=num_images, min_pixels=min_pixels,
                            max_pixels=max_pixels)
    if seq_len - max_llm_image_tokens - 2 < 0:
        raise RuntimeError(
            f"Qwen2-VL cannot process {num_images} images in a prompt, "
            "please increase max_model_len or reduce image limit by "
            "--limit-mm-per-prompt.")
    @staticmethod
    def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
        """
        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
        """
        multi_data = Qwen2VLMultiModalDataItems()

    # Check video counts.
    num_videos = mm_counts["video"]
    max_resized_height, max_resized_width, max_llm_video_tokens = \
        _get_max_image_info(image_processor, data_type_key="video",
                            mm_count=num_videos, min_pixels=min_pixels,
                            max_pixels=max_pixels)
    if seq_len - max_llm_video_tokens - 2 < 0:
        raise RuntimeError(
            f"Qwen2-VL cannot process {num_videos} videos in a prompt, "
            "please increase max_model_len or reduce video limit by "
            "--limit-mm-per-prompt.")
        for k, v in data.items():
            # TODO: Make a separate modality for embedding inputs
            # to avoid confusion
            # yapf: disable
            if k == "video":
                # Special case since even a single item can be a list
                multi_data[k] = (  # type: ignore[index]
                    v if (isinstance(v, (dict, torch.Tensor))  # type: ignore[assignment]
                          or is_list_of(v, list)) else [v]
                )
            elif k in ("image", "audio"):
                multi_data[k] = (  # type: ignore[index]
                    v if isinstance(v, (dict, torch.Tensor, list)) else [v]
                )
            else:
                multi_data[k] = v if isinstance(v, list) else [v]  # type: ignore[index]
            # yapf: enable

    hf_config = ctx.get_hf_config(Qwen2VLConfig)
        return multi_data

    dummy_seqdata = SequenceData.from_prompt_token_counts(
        (hf_config.vision_start_token_id, 1),
        (hf_config.image_token_id, max_llm_image_tokens),
        (hf_config.vision_end_token_id, 1),
        (0, seq_len - max_llm_image_tokens - 2),
    )

    dummy_image = Image.new("RGB", (max_resized_width, max_resized_height),
                            color=0)

    return DummyData(dummy_seqdata, {
        "image":
        dummy_image if num_images == 1 else [dummy_image] * num_images
    })
    def get_item_counts(self) -> Mapping[str, int]:
        return {
            m: (
                len(items[f"{m}_grid_thw"])  # type: ignore
                if isinstance(items, dict) else len(items))
            for m, items in self.items()
        }

def _get_llm_num_vision_tokens(
    mm_inputs: list,
    data_type_key: str,
    image_processor,
    min_pixels: int,
    max_pixels: int,
):
    """Get number of vision tokens of multimodal inputs.
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):

    This method is derived from `transformers.models.qwen2_vl.
    image_processing_qwen2_vl.Qwen2VLImageProcessor._preprocess`.
    """
    image = to_numpy_array(mm_inputs[0])
    input_data_format = infer_channel_dimension_format(image)
    height, width = get_image_size(image, channel_dim=input_data_format)
    def _get_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        return Qwen2VLMultiModalDataItems.from_dict(mm_data)

    _, _, llm_num_vision_tokens = _get_vision_info(
        image_processor,
        height=height,
        width=width,
        min_pixels=min_pixels,
        max_pixels=max_pixels,
        do_resize=image_processor.do_resize,
        data_type_key=data_type_key,
        mm_count=len(mm_inputs),
    )
    return llm_num_vision_tokens
    def _get_hf_processor(
        self,
        *,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
    ) -> Qwen2VLProcessor:
        hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor)
        image_processor = _get_image_processor(hf_processor)

        if min_pixels:
            image_processor.min_pixels = min_pixels
        if max_pixels:
            image_processor.max_pixels = max_pixels
        if max_pixels or min_pixels:
            image_processor.size = {
                "min_pixels": image_processor.min_pixels,
                "max_pixels": image_processor.max_pixels,
            }

def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
                       data_type_key: str, image_processor: Any,
                       prompt_token_ids: List[int], min_pixels: Optional[int],
                       max_pixels: Optional[int]) -> List[int]:
    """
    Expand pad tokens for multi-modal inputs (e.g., images or videos).
        return hf_processor

    Args:
        inputs (list): The multi-modal inputs (e.g., images or videos).
        token_id (int): The token ID used to represent the multi-modal input.
        make_batched_fn (Callable): A function to batch the inputs.
        data_type_key (str): The type of the multi-modal input.
        image_processor (Any): The image processor used to process the inputs.
        prompt_token_ids (List[int]): The list of token IDs in the prompt.
        min_pixels (int): min pixels to used for img processing
        max_pixels (int): max pixels to be used for img processing
    def _get_processor_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        processor_data = dict[str, Any]()
        passthrough_data = dict[str, Any]()

    Returns:
        List[int]: The list of token IDs for the multi-modal inputs.
    """
    indices = [
        idx for idx, token in enumerate(prompt_token_ids) if token == token_id
    ]
    inputs = make_batched_fn(inputs)
    assert len(indices) == len(inputs)
        for k, v in mm_items.items():
            # TODO: Make a separate modality for embedding inputs
            # to avoid confusion
            if k in ("image", "video", "audio"):
                if isinstance(v, dict):
                    # Pass through embedding inputs (dict)
                    passthrough_data.update(v)
                elif isinstance(v, torch.Tensor) and v.ndim == 3:
                    # Pass through embedding inputs (single)
                    passthrough_data[f"{k}_embeds"] = [v]
                elif (is_list_of(v, torch.Tensor) and len(v) > 0
                      and v[0].ndim == 2):
                    # Pass through embedding inputs (multi)
                    passthrough_data[f"{k}_embeds"] = v
                else:
                    # Map keys to plural form, e.g.: image -> images
                    processor_data[f"{k}s"] = v
            else:
                processor_data[k] = v

    prompt_token_ids_with_data = []
    for cnt, data in enumerate(inputs):
        num_tokens = _get_llm_num_vision_tokens(
            [data] if data_type_key == "image" else data,
            data_type_key=data_type_key,
            image_processor=image_processor,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        return processor_data, passthrough_data

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_inputs: BatchFeature,
        mm_processor_kwargs: Mapping[str, object],
    ) -> list[PromptReplacement]:
        hf_processor = self._get_hf_processor()
        image_processor = _get_image_processor(hf_processor)

        # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
        # image_token and video_token registered
        placeholder = {
            "image": hf_processor.image_token,
            "video": hf_processor.video_token,
        }
        merge_length = image_processor.merge_size**2

        def get_replacement_qwen2vl(item_idx: int, modality: str):
            grid_thw = hf_inputs[f"{modality}_grid_thw"][item_idx]
            num_tokens = grid_thw.prod() // merge_length
            return placeholder[modality] * num_tokens

        return [
            PromptReplacement(
                modality=modality,
                target=placeholder[modality],
                replacement=partial(get_replacement_qwen2vl,
                                    modality=modality),
            ) for modality in ("image", "video")
        ]

    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts["image"]
        hf_processor = self._get_hf_processor()
        image_token: str = hf_processor.image_token
        image_processor = _get_image_processor(hf_processor)

        data = {}
        resized_height, resized_width = smart_resize(
            height=9999999,
            width=9999999,
            factor=image_processor.patch_size * image_processor.merge_size,
            min_pixels=image_processor.min_pixels,
            max_pixels=image_processor.max_pixels,
        )

        dummy_image = Image.new("RGB", (resized_width, resized_height),
                                color=0)
        data["image"] = [dummy_image] * num_images

        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data=data,
            mm_processor_kwargs={},
        )
            if cnt == 0:
                end_idx = indices[cnt]
                non_data_tokens = prompt_token_ids[:end_idx]
            else:
                non_data_tokens = prompt_token_ids[indices[cnt - 1] +
                                                   1:indices[cnt]]
            prompt_token_ids_with_data.extend(non_data_tokens)
            prompt_token_ids_with_data.extend(token_id for _ in range(num_tokens))
    prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:])
    return prompt_token_ids_with_data

def input_processor_for_qwen2_vl(
    ctx: InputContext,
    inputs: DecoderOnlyInputs,
    *,
    min_pixels: Optional[int] = None,
    max_pixels: Optional[int] = None,
) -> DecoderOnlyInputs:
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None:
        return inputs

    image_inputs = multi_modal_data.get("image", None)
    video_inputs = multi_modal_data.get("video", None)

    processor = cached_get_processor(ctx.model_config.model)
    image_processor = processor.image_processor
    # Apply processor kwarg overrides for image processor options
    min_pixels = min_pixels if min_pixels else image_processor.min_pixels
    max_pixels = max_pixels if max_pixels else image_processor.max_pixels

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(Qwen2VLConfig)

    # To avoid redundant processing of vision objects (resize, rescale, etc.),
    # we extract code of calculating number of vision tokens from
    # `transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor`.
    #
    # The following code is equivalent to:
    #    prompt = inputs["prompt"]
    #    inputs = processor(text=[prompt],
    #                       images=image_inputs,
    #                       videos=video_inputs,
    #                       padding=True,
    #                       return_tensors="pt")
    #    prompt_token_ids = inputs["input_ids"][0].tolist()

    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)

    prompt_token_ids = inputs["prompt_token_ids"]

    # Expand image pad tokens.

    if image_inputs is not None:
        if isinstance(image_inputs, dict):
            prompt_token_ids_with_image = []
            image_indices = [
                idx for idx, token in enumerate(prompt_token_ids)
                if token == hf_config.image_token_id
            ]

            # ensure all image tokens have grid_thw
            assert \
                len(image_indices) == image_inputs["image_grid_thw"].size(0), \
                "image token num does not match image_grid_thw.shape"

            image_counter = 0
            pad_token_counter = 0
            for idx, token in enumerate(prompt_token_ids):
                if idx in image_indices:
                    grid_thw = image_inputs["image_grid_thw"][image_counter]
                    grid_t, grid_h, grid_w = grid_thw
                    num_pad_tokens = (grid_t * grid_h * grid_w //
                                      image_processor.merge_size //
                                      image_processor.merge_size)
                    prompt_token_ids_with_image.extend([token] *
                                                       num_pad_tokens)
                    image_counter += 1
                    pad_token_counter += num_pad_tokens
                else:
                    prompt_token_ids_with_image.append(token)

            # ensure all embeddings are used
            assert \
                pad_token_counter == image_inputs["image_embeds"].size(0), \
                "image_embeds.shape does not match image_grid_thw"

            prompt_token_ids = prompt_token_ids_with_image
        else:
            prompt_token_ids = _expand_pad_tokens(image_inputs,
                                                  hf_config.image_token_id,
                                                  make_batched_images,
                                                  "image",
                                                  image_processor,
                                                  prompt_token_ids,
                                                  min_pixels=min_pixels,
                                                  max_pixels=max_pixels)

    if video_inputs is not None:
        if isinstance(video_inputs, dict):
            prompt_token_ids_with_video = []
            video_indices = [
                idx for idx, token in enumerate(prompt_token_ids)
                if token == hf_config.video_token_id
            ]

            # ensure all video tokens have grid_thw
            assert \
                len(video_indices) == video_inputs["video_grid_thw"].size(0), \
                "video token num does not match video_grid_thw.shape"

            video_counter = 0
            pad_token_counter = 0
            for idx, token in enumerate(prompt_token_ids):
                if idx in video_indices:
                    grid_thw = video_inputs["video_grid_thw"][video_counter]
                    grid_t, grid_h, grid_w = grid_thw
                    num_pad_tokens = (grid_t * grid_h * grid_w //
                                      image_processor.merge_size //
                                      image_processor.merge_size)
                    prompt_token_ids_with_video.extend([token] *
                                                       num_pad_tokens)
                    video_counter += 1
                    pad_token_counter += num_pad_tokens
                else:
                    prompt_token_ids_with_video.append(token)

            # ensure all embeddings are used
            assert \
                pad_token_counter == video_inputs["video_embeds"].size(0), \
                "video_embeds.shape does not match video_grid_thw"

            prompt_token_ids = prompt_token_ids_with_video
        else:
            prompt_token_ids = _expand_pad_tokens(video_inputs,
                                                  hf_config.video_token_id,
                                                  make_batched_videos,
                                                  "video",
                                                  image_processor,
                                                  prompt_token_ids,
                                                  min_pixels=min_pixels,
                                                  max_pixels=max_pixels)

    prompt = inputs.get("prompt")
    if prompt is None:
        prompt = tokenizer.decode(prompt_token_ids)

    return token_inputs(
        prompt_token_ids=prompt_token_ids,
        prompt=prompt,
        multi_modal_data=multi_modal_data,
    )


@MULTIMODAL_REGISTRY.register_image_input_mapper(
    image_input_mapper_for_qwen2_vl)
@MULTIMODAL_REGISTRY.register_input_mapper("video",
                                           video_input_mapper_for_qwen2_vl)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "video", get_max_qwen2_vl_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl)
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsLoRA, SupportsPP):
    packed_modules_mapping = {
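Note (not part of the diff): how the PromptReplacement defined in _get_prompt_replacements expands a single placeholder. The grid_thw value below is made up purely for illustration.

# Illustration only: expansion of one <|image_pad|> placeholder.
import torch

merge_length = 2**2                          # image_processor.merge_size ** 2
grid_thw = torch.tensor([1, 34, 34])         # hypothetical (t, h, w) patch grid for one image
num_tokens = int(grid_thw.prod()) // merge_length   # 1156 // 4 = 289

# The single <|image_pad|> target is replaced by num_tokens copies of the
# placeholder, which then tokenize into 289 image token IDs.
replacement = "<|image_pad|>" * num_tokens
assert replacement.count("<|image_pad|>") == 289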
@@ -1110,7 +904,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        config: Qwen2VLConfig = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
@@ -220,15 +220,18 @@ class MultiModalDataItems(UserDict[str, list[Any]]):
        multi_data = MultiModalDataItems()

        for k, v in data.items():
            # TODO: Make a separate modality for embedding inputs
            # to avoid confusion
            # yapf: disable
            if k == "video":
                # Special case since even a single item can be a list
                multi_data[k] = (  # type: ignore[index]
                    v if is_list_of(v, (list, torch.Tensor)) else [v]
                    v if (isinstance(v, torch.Tensor)
                          or is_list_of(v, list)) else [v]
                )
            elif k in ("image", "audio"):
                multi_data[k] = (  # type: ignore[index]
                    v if isinstance(v, (list, torch.Tensor)) else [v]
                    v if isinstance(v, (torch.Tensor, list)) else [v]
                )
            else:
                multi_data[k] = v if isinstance(v, list) else [v]  # type: ignore[index]
@@ -252,6 +255,9 @@ class MultiModalDataItems(UserDict[str, list[Any]]):
    def audios(self) -> Sequence[AudioItem]:
        return self.get("audio", [])

    def get_item_counts(self) -> Mapping[str, int]:
        return {m: len(items) for m, items in self.items()}

    def get_image_size(self, item_idx: int) -> ImageSize:
        image = self.images[item_idx]

@@ -612,6 +618,12 @@ class BaseMultiModalProcessor(ABC):
    def _get_tokenizer(self) -> AnyTokenizer:
        return self.ctx.tokenizer

    def _get_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        return MultiModalDataItems.from_dict(mm_data)

    @abstractmethod
    def _get_prompt_replacements(
        self,
@@ -778,7 +790,7 @@ class BaseMultiModalProcessor(ABC):
        3. Extract information about the placeholder tokens from the
           processed token IDs.
        """
        mm_items = MultiModalDataItems.from_dict(mm_data)
        mm_items = self._get_mm_items(mm_data)

        hf_inputs = self._apply_hf_processor(prompt_text, mm_items,
                                             mm_processor_kwargs)
@@ -791,7 +803,7 @@ class BaseMultiModalProcessor(ABC):

        # If HF processor already inserts placeholder tokens,
        # there is no need for us to insert them
        mm_item_counts = {m: len(items) for m, items in mm_items.items()}
        mm_item_counts = mm_items.get_item_counts()
        all_placeholders = self._find_placeholders(all_prompt_repls,
                                                   prompt_ids, mm_item_counts)

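Note (not part of the diff): a conceptual sketch of driving the merged processor end to end, mirroring test_processor_override above. It assumes a vLLM model_config and a PIL image are already in scope; the prompt text and kwargs are illustrative.

# Illustration only: using the merged Qwen2-VL multimodal processor directly.
from transformers import AutoTokenizer

from vllm.inputs import InputProcessingContext
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
ctx = InputProcessingContext(model_config, tokenizer)

processor = Qwen2VLMultiModalProcessor(ctx)
# apply() runs the HF processor once, expands each <|image_pad|> according to
# image_grid_thw, and records the placeholder ranges for the engine.
processed = processor.apply(
    "<|vision_start|><|image_pad|><|vision_end|>Describe the image.",
    {"image": [image]},
    {"min_pixels": 64**2, "max_pixels": 512**2},
)
print(len(processed["prompt_token_ids"]),
      processed["mm_kwargs"]["pixel_values"].shape)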