[Model] Refactor Qwen2-VL to use merged multimodal processor (#11258)

Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Isotr0py 2024-12-20 00:28:00 +08:00 committed by GitHub
parent 7379b3d4b2
commit e24113a8fe
5 changed files with 272 additions and 522 deletions
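With this refactor, the separate input mapper, input processor, and dummy-data hooks for Qwen2-VL are replaced by a single Qwen2VLMultiModalProcessor registered through MULTIMODAL_REGISTRY.register_processor. A minimal sketch of how the merged processor is driven, modeled on the updated test in this diff (the model_config object is assumed to be a vLLM ModelConfig for the same checkpoint; the test builds it via its build_model_context helper):

    from PIL import Image
    from transformers import AutoTokenizer

    from vllm.inputs import InputProcessingContext
    from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor

    model = "Qwen/Qwen2-VL-7B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model)
    ctx = InputProcessingContext(model_config, tokenizer)  # model_config: assumed vLLM ModelConfig

    # One object now covers placeholder expansion, HF preprocessing and
    # dummy-data generation (previously three separate registry hooks).
    processor = Qwen2VLMultiModalProcessor(ctx)
    image = Image.new("RGB", (980, 980))  # any PIL image
    out = processor.apply(
        "<|vision_start|><|image_pad|><|vision_end|>Describe the image.",
        {"image": [image]},
        {},  # mm_processor_kwargs, e.g. {"min_pixels": 64**2, "max_pixels": 512**2}
    )
    # out["prompt_token_ids"] has <|image_pad|> expanded per image grid;
    # out["mm_kwargs"]["pixel_values"] holds the processed vision inputs.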

View File

@@ -447,7 +447,6 @@ def run_qwen_vl(question: str, modality: str):
 # Qwen2-VL
 def run_qwen2_vl(question: str, modality: str):
-    assert modality == "image"
 
     model_name = "Qwen/Qwen2-VL-7B-Instruct"
@@ -463,8 +462,13 @@ def run_qwen2_vl(question: str, modality: str):
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
 
+    if modality == "image":
+        placeholder = "<|image_pad|>"
+    elif modality == "video":
+        placeholder = "<|video_pad|>"
+
     prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
-              "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
+              f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
               f"{question}<|im_end|>\n"
              "<|im_start|>assistant\n")
     stop_token_ids = None

View File

@@ -1,12 +1,9 @@
 from typing import Any, Dict, Tuple
 
 import pytest
-import torch
-from PIL.Image import Image
 from transformers import AutoTokenizer
 
-from vllm.inputs import InputContext, token_inputs
-from vllm.multimodal import MultiModalRegistry
+from vllm.inputs import InputContext, InputProcessingContext
 
 from .....conftest import _ImageAssets
 from ....utils import build_model_context
@@ -20,22 +17,9 @@ MAX_PIXELS = "max_pixels"
 # NOTE: Qwen2VL supports multiple input modalities, so it registers multiple
 # input mappers.
 @pytest.fixture()
-def image_input_mapper_for_qwen2_vl():
-    from vllm.model_executor.models.qwen2_vl import (
-        image_input_mapper_for_qwen2_vl)
-    return image_input_mapper_for_qwen2_vl
-
-
-@pytest.fixture()
-def input_processor_for_qwen2_vl():
-    from vllm.model_executor.models.qwen2_vl import (
-        input_processor_for_qwen2_vl)
-    return input_processor_for_qwen2_vl
-
-
-@pytest.fixture()
-def qwen2_vl_context() -> InputContext:
-    return build_model_context(model_name=MODEL)
+def processor_for_qwen2_vl():
+    from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor
+    return Qwen2VLMultiModalProcessor
 
 
 @pytest.fixture()
@@ -45,12 +29,6 @@ def get_max_qwen2_vl_image_tokens():
     return get_max_qwen2_vl_image_tokens
 
 
-@pytest.fixture()
-def dummy_data_for_qwen2_vl():
-    from vllm.model_executor.models.qwen2_vl import dummy_data_for_qwen2_vl
-    return dummy_data_for_qwen2_vl
-
-
 @pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [
     ({}, 1225),
     ({
@@ -58,110 +36,70 @@ def dummy_data_for_qwen2_vl():
         MAX_PIXELS: 512**2
     }, 324),
 ])
-def test_qwen2_vl_max_image_tokens(get_max_qwen2_vl_image_tokens,
-                                   qwen2_vl_context: InputContext,
-                                   mm_processor_kwargs: Dict[str, Any],
-                                   expected_max_tokens: int):
+@pytest.mark.parametrize("model", [MODEL])
+def test_qwen2_vl_max_image_tokens(
+    get_max_qwen2_vl_image_tokens,
+    model: str,
+    mm_processor_kwargs: Dict[str, Any],
+    expected_max_tokens: int,
+):
     """Ensure that the max token calc handles min/max pixels properly."""
-    actual_max_tokens = get_max_qwen2_vl_image_tokens(qwen2_vl_context,
-                                                      **mm_processor_kwargs)
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        mm_processor_kwargs=None,
+    )
+
+    actual_max_tokens = get_max_qwen2_vl_image_tokens(
+        InputContext(ctx.model_config), **mm_processor_kwargs)
+
     assert actual_max_tokens == expected_max_tokens
 
 
-@pytest.mark.parametrize("mm_processor_kwargs,token_count,img_size", [
-    [{}, 1225, (980, 980)],
-    [{
-        MIN_PIXELS: 64**2,
-        MAX_PIXELS: 512**2
-    }, 324, (504, 504)],
-])
-def test_qwen2_vl_dummy_data(dummy_data_for_qwen2_vl,
-                             qwen2_vl_context: InputContext,
-                             mm_processor_kwargs: Dict[str, Any],
-                             token_count: int, img_size: Tuple[int, int]):
-    """Ensure that the dummy data handles min/max pixels properly."""
-    seq_len = 3000
-    hf_config = qwen2_vl_context.get_hf_config()
-    image_token_id = hf_config.image_token_id
-
-    # NOTE: video value is required, but isn't actually used
-    # when making the dummy data except for error handling currently
-    dummy_data = dummy_data_for_qwen2_vl(
-        ctx=qwen2_vl_context,
-        seq_len=seq_len,
-        mm_counts={
-            "image": 1,
-            "video": 0
-        },
-        **mm_processor_kwargs,
-    )
-    seq_data = dummy_data.seq_data
-    mm_data = dummy_data.multi_modal_data
-
-    # Ensure we have the right number of placeholders for min/max pixel values
-    assert seq_data.get_token_ids().count(image_token_id) == token_count
-
-    # Ensure the images were resized correctly
-    image = mm_data["image"]
-    assert isinstance(image, Image)
-    assert image.size == img_size
-
-
-@pytest.mark.parametrize("mm_processor_kwargs,num_placeholders", [
-    ({}, 1426),
-    ({
-        MIN_PIXELS: 64**2,
-        MAX_PIXELS: 512**2
-    }, 330),
-])
-def test_input_processor(input_processor_for_qwen2_vl,
-                         qwen2_vl_context: InputContext,
-                         image_assets: _ImageAssets, num_placeholders: int,
-                         mm_processor_kwargs: Dict[str, Any]):
-    """Ensure that the image processor handles min/max pixels properly."""
-    tokenizer = AutoTokenizer.from_pretrained(MODEL)
-    prompt = "<|vision_start|><|image_pad|><|vision_end|>"
-
-    image = image_assets[0].pil_image
-    hf_config = qwen2_vl_context.get_hf_config()
-    image_token_id = hf_config.image_token_id
-
-    inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
-                          prompt=prompt,
-                          multi_modal_data={"image": [image]})
-
-    processed_inputs = input_processor_for_qwen2_vl(qwen2_vl_context, inputs,
-                                                    **mm_processor_kwargs)
-    assert processed_inputs["prompt_token_ids"].count(
-        image_token_id) == num_placeholders
-    assert len(processed_inputs["multi_modal_data"]["image"]) == 1
-
-
-@pytest.mark.parametrize("mm_processor_kwargs,pixels_shape", [
-    ({}, [5704, 1176]),
-    ({
-        MIN_PIXELS: 64**2,
-        MAX_PIXELS: 512**2
-    }, [1320, 1176]),
-])
-def test_image_mapper_override(qwen2_vl_context: InputContext,
-                               image_assets: _ImageAssets,
-                               mm_processor_kwargs: Dict[str, Any],
-                               pixels_shape: Tuple[int, int]):
-    """Ensure that the image mapper handles min/max pixels properly."""
-    mm_registry = MultiModalRegistry()
-    mm_registry.init_mm_limits_per_prompt(qwen2_vl_context.model_config)
-
-    image = image_assets[0].pil_image
-
-    mapped_output = mm_registry.map_input(
-        qwen2_vl_context.model_config,
-        {"image": image},
-        mm_processor_kwargs=mm_processor_kwargs,
-    )
-
-    # Dimension 0 of pixel values should match the product of image_grid_thw
-    actual_pixels_shape = mapped_output["pixel_values"].shape
-    assert list(actual_pixels_shape) == pixels_shape
-
-    assert actual_pixels_shape[0] == torch.prod(
-        mapped_output["image_grid_thw"])
+@pytest.mark.parametrize(
+    "mm_processor_kwargs, expected_toks_per_img, expected_pixels_shape", [
+        ({}, 1426, (5704, 1176)),
+        ({
+            MIN_PIXELS: 64**2,
+            MAX_PIXELS: 512**2
+        }, 330, (1320, 1176)),
+    ])
+@pytest.mark.parametrize("model", [MODEL])
+@pytest.mark.parametrize("num_imgs", [1, 2])
+def test_processor_override(
+    processor_for_qwen2_vl,
+    image_assets: _ImageAssets,
+    model: str,
+    mm_processor_kwargs: Dict[str, Any],
+    expected_toks_per_img: int,
+    expected_pixels_shape: Tuple[int, int],
+    num_imgs: int,
+):
+    """Ensure Qwen2VLMultiModalProcessor handles min/max pixels properly."""
+    # Same as the previous test - don't initialize mm_processor_kwargs
+    # in this test and assume that the kwargs will be correctly expanded by
+    # the partial when calling the custom input processor.
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        mm_processor_kwargs=None,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+    ctx = InputProcessingContext(ctx.model_config, tokenizer)
+
+    # Build the image str / prompt based on the number of images we pass
+    prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
+    images = [image_assets[0].pil_image] * num_imgs
+    mm_data = {"image": images}
+
+    processor = processor_for_qwen2_vl(ctx)
+    processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
+
+    # Ensure we have the right number of placeholders per num_crops size
+    hf_processor = processor._get_hf_processor(**mm_processor_kwargs)
+    image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token)
+    img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
+    pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape
+
+    assert img_tok_count == expected_toks_per_img * num_imgs
+    assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs
+    assert pixel_shape[1] == expected_pixels_shape[1]

View File

@@ -164,7 +164,9 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
         self,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        audio_len = get_max_qwen2_audio_audio_tokens(self.ctx)
+        feature_extractor = self._get_feature_extractor()
+        sampling_rate = feature_extractor.sampling_rate
+        audio_len = feature_extractor.chunk_length * sampling_rate
         audio_count = mm_counts["audio"]
 
         audio = np.zeros(audio_len)

View File

@@ -22,28 +22,26 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 from functools import cached_property, partial
-from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
-                    Optional, Set, Tuple, Type, TypedDict, Union)
+from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
+                    Tuple, Type, TypedDict, Union)
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange, repeat
 from PIL import Image
-from transformers.image_utils import (get_image_size,
-                                      infer_channel_dimension_format,
-                                      to_numpy_array)
+from transformers import BatchFeature
+from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
+                                          Qwen2VLProcessor)
 from transformers.models.qwen2_vl.configuration_qwen2_vl import (
     Qwen2VLConfig, Qwen2VLVisionConfig)
-from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
-    make_batched_images, make_batched_videos, smart_resize)
+from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
 
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
+from vllm.inputs import InputContext
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.activation import QuickGELU
@@ -56,14 +54,14 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.image import cached_get_image_processor
-from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
-                                    MultiModalKwargs, NestedTensors)
-from vllm.multimodal.utils import cached_get_tokenizer
+from vllm.multimodal.inputs import MultiModalDataDict, NestedTensors
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        MultiModalDataItems, ProcessorInputs,
+                                        PromptReplacement)
 from vllm.platforms import _Backend
-from vllm.sequence import IntermediateTensors, SequenceData
+from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
-from vllm.transformers_utils.processor import cached_get_processor
+from vllm.utils import is_list_of
 
 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
@@ -159,7 +157,7 @@ class Qwen2VisionMLP(nn.Module):
     def __init__(
         self,
         in_features: int,
-        hidden_features: int = None,
+        hidden_features: int,
         act_layer: Type[nn.Module] = QuickGELU,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
@@ -644,78 +642,8 @@
 # === Vision input helpers === #
 
 
-def get_mm_processor_kwargs(
-        min_pixels: Optional[int] = None,
-        max_pixels: Optional[int] = None) -> Dict[str, int]:
-    mm_processor_kwargs = {}
-    if min_pixels:
-        mm_processor_kwargs["min_pixels"] = min_pixels
-    if max_pixels:
-        mm_processor_kwargs["max_pixels"] = max_pixels
-    return mm_processor_kwargs
-
-
-def mm_input_mapper_for_qwen2_vl(
-    ctx: InputContext,
-    data: MultiModalData[object],
-    data_type_key: str,
-    *,
-    min_pixels: Optional[int] = None,
-    max_pixels: Optional[int] = None,
-) -> MultiModalKwargs:
-    """Input mapper for Qwen2-VL."""
-    if data_type_key == "image" and isinstance(data, dict):
-        return MultiModalKwargs({
-            "image_embeds": data.get("image_embeds"),
-            "image_grid_thw": data.get("image_grid_thw"),
-        })
-    if data_type_key == "video" and isinstance(data, dict):
-        return MultiModalKwargs({
-            "video_embeds": data.get("video_embeds"),
-            "video_grid_thw": data.get("video_grid_thw"),
-        })
-
-    model_config = ctx.model_config
-    # Handle mm processor kwargs; we pass these at creation time
-    # because preprocess() in transformers doesn't expose them
-    mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels,
-                                                  max_pixels=max_pixels)
-    image_processor = cached_get_image_processor(
-        model_config.model,
-        trust_remote_code=model_config.trust_remote_code,
-        **mm_processor_kwargs,
-    )
-    if image_processor is None:
-        raise RuntimeError("No HuggingFace processor is available "
-                           "to process the image object")
-
-    images = None
-    videos = None
-    if data_type_key == "image":
-        images = data
-    else:
-        assert data_type_key == "video"
-        videos = data
-
-    try:
-        batch_data = image_processor \
-            .preprocess(images=images, videos=videos, return_tensors="pt") \
-            .data
-    except Exception:
-        logger.error("Failed to process image (%s)", data)
-        raise
-
-    return MultiModalKwargs(batch_data)
-
-
-image_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl,
-                                          data_type_key="image")
-video_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl,
-                                          data_type_key="video")
-
-
 def _get_vision_info(
-    image_processor,
+    vision_config: Qwen2VLVisionConfig,
     height: int,
     width: int,
     min_pixels: int,
@@ -726,12 +654,15 @@ def _get_vision_info(
 ):
     """Get information (resized height / width and number of vision tokens)
     of input image / video frame."""
+    patch_size = vision_config.patch_size
+    merge_size = vision_config.spatial_merge_size
+    temporal_patch_size = vision_config.temporal_patch_size
     if do_resize:
         resized_height, resized_width = smart_resize(
             height=height,
             width=width,
-            factor=image_processor.patch_size * image_processor.merge_size,
+            factor=patch_size * merge_size,
             min_pixels=min_pixels,
             max_pixels=max_pixels,
         )
@@ -742,54 +673,41 @@
         grid_t = mm_count
     else:
         assert data_type_key == "video"
-        grid_t = max(mm_count // image_processor.temporal_patch_size, 1)
-    grid_h = resized_height // image_processor.patch_size
-    grid_w = resized_width // image_processor.patch_size
+        grid_t = max(mm_count // temporal_patch_size, 1)
+    grid_h = resized_height // patch_size
+    grid_w = resized_width // patch_size
     vision_tokens = grid_t * grid_h * grid_w
-    llm_num_vision_tokens = (vision_tokens // image_processor.merge_size //
-                             image_processor.merge_size)
+    llm_num_vision_tokens = vision_tokens // (merge_size**2)
     return resized_height, resized_width, llm_num_vision_tokens
 
 
-def _get_max_image_info(
-    image_processor,
-    data_type_key: str = "image",
-    mm_count: int = 1,
-    min_pixels: Optional[int] = None,
-    max_pixels: Optional[int] = None,
-):
-    # Limit min / max pixels unless they're explicitly provided
-    if min_pixels is None:
-        min_pixels = max(image_processor.min_pixels, 28 * 28)
-    if max_pixels is None:
-        max_pixels = min(image_processor.max_pixels, 1280 * 28 * 28)
-
-    return _get_vision_info(
-        image_processor,
-        height=9999999,
-        width=9999999,
-        min_pixels=min_pixels,
-        max_pixels=max_pixels,
-        data_type_key=data_type_key,
-        mm_count=mm_count,
-    )
+def _get_image_processor(hf_processor: Qwen2VLProcessor):
+    image_processor = hf_processor.image_processor  # type: ignore
+    assert isinstance(image_processor, Qwen2VLImageProcessor)
+    return image_processor
 
 
 def get_max_qwen2_vl_mm_tokens(ctx: InputContext,
                                data_type_key: str,
                                *,
-                               min_pixels=None,
-                               max_pixels=None) -> int:
-    mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels,
-                                                  max_pixels=max_pixels)
-    image_processor = cached_get_image_processor(ctx.model_config.model,
-                                                 **mm_processor_kwargs)
-    max_resized_height, max_resized_width, max_llm_image_tokens = \
-        _get_max_image_info(image_processor, data_type_key=data_type_key,
-                            mm_count=1, min_pixels=min_pixels,
-                            max_pixels=max_pixels)
+                               min_pixels: Optional[int] = None,
+                               max_pixels: Optional[int] = None) -> int:
+    hf_config = ctx.get_hf_config(Qwen2VLConfig)
+    vision_config = hf_config.vision_config
+
+    hf_processor = ctx.get_hf_processor(Qwen2VLProcessor)
+    image_processor = _get_image_processor(hf_processor)
+
+    _, _, max_llm_image_tokens = _get_vision_info(
+        vision_config,
+        height=9999999,
+        width=9999999,
+        min_pixels=min_pixels or image_processor.min_pixels,
+        max_pixels=max_pixels or image_processor.max_pixels,
+        data_type_key=data_type_key,
+    )
     return max_llm_image_tokens
@@ -799,290 +717,166 @@ get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens,
                                         data_type_key="video")
 
 
-def dummy_data_for_qwen2_vl(
-    ctx: InputContext,
-    seq_len: int,
-    mm_counts: Mapping[str, int],
-    *,
-    min_pixels: Optional[int] = None,
-    max_pixels: Optional[int] = None
-) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
-    mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels,
-                                                  max_pixels=max_pixels)
-    image_processor = cached_get_image_processor(ctx.model_config.model,
-                                                 **mm_processor_kwargs)
-
-    num_images = mm_counts["image"]
-    max_resized_height, max_resized_width, max_llm_image_tokens = \
-        _get_max_image_info(image_processor, data_type_key="image",
-                            mm_count=num_images, min_pixels=min_pixels,
-                            max_pixels=max_pixels)
-    if seq_len - max_llm_image_tokens - 2 < 0:
-        raise RuntimeError(
-            f"Qwen2-VL cannot process {num_images} images in a prompt, "
-            "please increase max_model_len or reduce image limit by "
-            "--limit-mm-per-prompt.")
-
-    # Check video counts.
-    num_videos = mm_counts["video"]
-    max_resized_height, max_resized_width, max_llm_video_tokens = \
-        _get_max_image_info(image_processor, data_type_key="video",
-                            mm_count=num_videos, min_pixels=min_pixels,
-                            max_pixels=max_pixels)
-    if seq_len - max_llm_video_tokens - 2 < 0:
-        raise RuntimeError(
-            f"Qwen2-VL cannot process {num_videos} videos in a prompt, "
-            "please increase max_model_len or reduce video limit by "
-            "--limit-mm-per-prompt.")
-
-    hf_config = ctx.get_hf_config(Qwen2VLConfig)
-
-    dummy_seqdata = SequenceData.from_prompt_token_counts(
-        (hf_config.vision_start_token_id, 1),
-        (hf_config.image_token_id, max_llm_image_tokens),
-        (hf_config.vision_end_token_id, 1),
-        (0, seq_len - max_llm_image_tokens - 2),
-    )
-
-    dummy_image = Image.new("RGB", (max_resized_width, max_resized_height),
-                            color=0)
-
-    return DummyData(dummy_seqdata, {
-        "image":
-        dummy_image if num_images == 1 else [dummy_image] * num_images
-    })
-
-
-def _get_llm_num_vision_tokens(
-    mm_inputs: list,
-    data_type_key: str,
-    image_processor,
-    min_pixels: int,
-    max_pixels: int,
-):
-    """Get number of vision tokens of multimodal inputs.
-
-    This method is derived from `transformers.models.qwen2_vl.
-    image_processing_qwen2_vl.Qwen2VLImageProcessor._preprocess`.
-    """
-    image = to_numpy_array(mm_inputs[0])
-    input_data_format = infer_channel_dimension_format(image)
-    height, width = get_image_size(image, channel_dim=input_data_format)
-
-    _, _, llm_num_vision_tokens = _get_vision_info(
-        image_processor,
-        height=height,
-        width=width,
-        min_pixels=min_pixels,
-        max_pixels=max_pixels,
-        do_resize=image_processor.do_resize,
-        data_type_key=data_type_key,
-        mm_count=len(mm_inputs),
-    )
-    return llm_num_vision_tokens
-
-
-def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
-                       data_type_key: str, image_processor: Any,
-                       prompt_token_ids: List[int], min_pixels: Optional[int],
-                       max_pixels: Optional[int]) -> List[int]:
-    """
-    Expand pad tokens for multi-modal inputs (e.g., images or videos).
-
-    Args:
-        inputs (list): The multi-modal inputs (e.g., images or videos).
-        token_id (int): The token ID used to represent the multi-modal input.
-        make_batched_fn (Callable): A function to batch the inputs.
-        data_type_key (str): The type of the multi-modal input.
-        image_processor (Any): The image processor used to process the inputs.
-        prompt_token_ids (List[int]): The list of token IDs in the prompt.
-        min_pixels (int): min pixels to used for img processing
-        max_pixels (int): max pixels to be used for img processing
-
-    Returns:
-        List[int]: The list of token IDs for the multi-modal inputs.
-    """
-    indices = [
-        idx for idx, token in enumerate(prompt_token_ids) if token == token_id
-    ]
-    inputs = make_batched_fn(inputs)
-    assert len(indices) == len(inputs)
-
-    prompt_token_ids_with_data = []
-    for cnt, data in enumerate(inputs):
-        num_tokens = _get_llm_num_vision_tokens(
-            [data] if data_type_key == "image" else data,
-            data_type_key=data_type_key,
-            image_processor=image_processor,
-            min_pixels=min_pixels,
-            max_pixels=max_pixels,
-        )
-        if cnt == 0:
-            end_idx = indices[cnt]
-            non_data_tokens = prompt_token_ids[:end_idx]
-        else:
-            non_data_tokens = prompt_token_ids[indices[cnt - 1] +
-                                               1:indices[cnt]]
-        prompt_token_ids_with_data.extend(non_data_tokens)
-        prompt_token_ids_with_data.extend(token_id for _ in range(num_tokens))
-    prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:])
-    return prompt_token_ids_with_data
-
-
-def input_processor_for_qwen2_vl(
-    ctx: InputContext,
-    inputs: DecoderOnlyInputs,
-    *,
-    min_pixels: Optional[int] = None,
-    max_pixels: Optional[int] = None,
-) -> DecoderOnlyInputs:
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None:
-        return inputs
-
-    image_inputs = multi_modal_data.get("image", None)
-    video_inputs = multi_modal_data.get("video", None)
-
-    processor = cached_get_processor(ctx.model_config.model)
-    image_processor = processor.image_processor
-    # Apply processor kwarg overrides for image processor options
-    min_pixels = min_pixels if min_pixels else image_processor.min_pixels
-    max_pixels = max_pixels if max_pixels else image_processor.max_pixels
-
-    model_config = ctx.model_config
-    hf_config = ctx.get_hf_config(Qwen2VLConfig)
-
-    # To avoid redundant processing of vision objects (resize, rescale, etc.),
-    # we extract code of calculating number of vision tokens from
-    # `transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor`.
-    #
-    # The following code is equivalent to:
-    #    prompt = inputs["prompt"]
-    #    inputs = processor(text=[prompt],
-    #                       images=image_inputs,
-    #                       videos=video_inputs,
-    #                       padding=True,
-    #                       return_tensors="pt")
-    #    prompt_token_ids = inputs["input_ids"][0].tolist()
-
-    tokenizer = cached_get_tokenizer(
-        model_config.tokenizer,
-        trust_remote_code=model_config.trust_remote_code)
-
-    prompt_token_ids = inputs["prompt_token_ids"]
-
-    # Expand image pad tokens.
-    if image_inputs is not None:
-        if isinstance(image_inputs, dict):
-            prompt_token_ids_with_image = []
-            image_indices = [
-                idx for idx, token in enumerate(prompt_token_ids)
-                if token == hf_config.image_token_id
-            ]
-
-            # ensure all image tokens have grid_thw
-            assert \
-                len(image_indices) == image_inputs["image_grid_thw"].size(0), \
-                "image token num does not match image_grid_thw.shape"
-
-            image_counter = 0
-            pad_token_counter = 0
-            for idx, token in enumerate(prompt_token_ids):
-                if idx in image_indices:
-                    grid_thw = image_inputs["image_grid_thw"][image_counter]
-                    grid_t, grid_h, grid_w = grid_thw
-                    num_pad_tokens = (grid_t * grid_h * grid_w //
-                                      image_processor.merge_size //
-                                      image_processor.merge_size)
-                    prompt_token_ids_with_image.extend([token] *
-                                                       num_pad_tokens)
-                    image_counter += 1
-                    pad_token_counter += num_pad_tokens
-                else:
-                    prompt_token_ids_with_image.append(token)
-
-            # ensure all embeddings are used
-            assert \
-                pad_token_counter == image_inputs["image_embeds"].size(0), \
-                "image_embeds.shape does not match image_grid_thw"
-
-            prompt_token_ids = prompt_token_ids_with_image
-        else:
-            prompt_token_ids = _expand_pad_tokens(image_inputs,
-                                                  hf_config.image_token_id,
-                                                  make_batched_images,
-                                                  "image",
-                                                  image_processor,
-                                                  prompt_token_ids,
-                                                  min_pixels=min_pixels,
-                                                  max_pixels=max_pixels)
-
-    if video_inputs is not None:
-        if isinstance(video_inputs, dict):
-            prompt_token_ids_with_video = []
-            video_indices = [
-                idx for idx, token in enumerate(prompt_token_ids)
-                if token == hf_config.video_token_id
-            ]
-
-            # ensure all video tokens have grid_thw
-            assert \
-                len(video_indices) == video_inputs["video_grid_thw"].size(0), \
-                "video token num does not match video_grid_thw.shape"
-
-            video_counter = 0
-            pad_token_counter = 0
-            for idx, token in enumerate(prompt_token_ids):
-                if idx in video_indices:
-                    grid_thw = video_inputs["video_grid_thw"][video_counter]
-                    grid_t, grid_h, grid_w = grid_thw
-                    num_pad_tokens = (grid_t * grid_h * grid_w //
-                                      image_processor.merge_size //
-                                      image_processor.merge_size)
-                    prompt_token_ids_with_video.extend([token] *
-                                                       num_pad_tokens)
-                    video_counter += 1
-                    pad_token_counter += num_pad_tokens
-                else:
-                    prompt_token_ids_with_video.append(token)
-
-            # ensure all embeddings are used
-            assert \
-                pad_token_counter == video_inputs["video_embeds"].size(0), \
-                "video_embeds.shape does not match video_grid_thw"
-
-            prompt_token_ids = prompt_token_ids_with_video
-        else:
-            prompt_token_ids = _expand_pad_tokens(video_inputs,
-                                                  hf_config.video_token_id,
-                                                  make_batched_videos,
-                                                  "video",
-                                                  image_processor,
-                                                  prompt_token_ids,
-                                                  min_pixels=min_pixels,
-                                                  max_pixels=max_pixels)
-
-    prompt = inputs.get("prompt")
-    if prompt is None:
-        prompt = tokenizer.decode(prompt_token_ids)
-
-    return token_inputs(
-        prompt_token_ids=prompt_token_ids,
-        prompt=prompt,
-        multi_modal_data=multi_modal_data,
-    )
-
-
-@MULTIMODAL_REGISTRY.register_image_input_mapper(
-    image_input_mapper_for_qwen2_vl)
-@MULTIMODAL_REGISTRY.register_input_mapper("video",
-                                           video_input_mapper_for_qwen2_vl)
+class Qwen2VLMultiModalDataItems(MultiModalDataItems):
+
+    @staticmethod
+    def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
+        """
+        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
+        """
+        multi_data = Qwen2VLMultiModalDataItems()
+
+        for k, v in data.items():
+            # TODO: Make a separate modality for embedding inputs
+            # to avoid confusion
+            # yapf: disable
+            if k == "video":
+                # Special case since even a single item can be a list
+                multi_data[k] = (  # type: ignore[index]
+                    v if (isinstance(v, (dict, torch.Tensor))  # type: ignore[assignment]
+                          or is_list_of(v, list)) else [v]
+                )
+            elif k in ("image", "audio"):
+                multi_data[k] = (  # type: ignore[index]
+                    v if isinstance(v, (dict, torch.Tensor, list)) else [v]
+                )
+            else:
+                multi_data[k] = v if isinstance(v, list) else [v]  # type: ignore[index]
+            # yapf: enable
+
+        return multi_data
+
+    def get_item_counts(self) -> Mapping[str, int]:
+        return {
+            m: (
+                len(items[f"{m}_grid_thw"])  # type: ignore
+                if isinstance(items, dict) else len(items))
+            for m, items in self.items()
+        }
+
+
+class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
+
+    def _get_mm_items(
+        self,
+        mm_data: MultiModalDataDict,
+    ) -> MultiModalDataItems:
+        return Qwen2VLMultiModalDataItems.from_dict(mm_data)
+
+    def _get_hf_processor(
+        self,
+        *,
+        min_pixels: Optional[int] = None,
+        max_pixels: Optional[int] = None,
+    ) -> Qwen2VLProcessor:
+        hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor)
+        image_processor = _get_image_processor(hf_processor)
+
+        if min_pixels:
+            image_processor.min_pixels = min_pixels
+        if max_pixels:
+            image_processor.max_pixels = max_pixels
+        if max_pixels or min_pixels:
+            image_processor.size = {
+                "min_pixels": image_processor.min_pixels,
+                "max_pixels": image_processor.max_pixels,
+            }
+
+        return hf_processor
+
+    def _get_processor_data(
+        self,
+        mm_items: MultiModalDataItems,
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        processor_data = dict[str, Any]()
+        passthrough_data = dict[str, Any]()
+
+        for k, v in mm_items.items():
+            # TODO: Make a separate modality for embedding inputs
+            # to avoid confusion
+            if k in ("image", "video", "audio"):
+                if isinstance(v, dict):
+                    # Pass through embedding inputs (dict)
+                    passthrough_data.update(v)
+                elif isinstance(v, torch.Tensor) and v.ndim == 3:
+                    # Pass through embedding inputs (single)
+                    passthrough_data[f"{k}_embeds"] = [v]
+                elif (is_list_of(v, torch.Tensor) and len(v) > 0
+                      and v[0].ndim == 2):
+                    # Pass through embedding inputs (multi)
+                    passthrough_data[f"{k}_embeds"] = v
+                else:
+                    # Map keys to plural form, e.g.: image -> images
+                    processor_data[f"{k}s"] = v
+            else:
+                processor_data[k] = v
+
+        return processor_data, passthrough_data
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_inputs: BatchFeature,
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> list[PromptReplacement]:
+        hf_processor = self._get_hf_processor()
+        image_processor = _get_image_processor(hf_processor)
+
+        # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
+        # image_token and video_token registered
+        placeholder = {
+            "image": hf_processor.image_token,
+            "video": hf_processor.video_token,
+        }
+        merge_length = image_processor.merge_size**2
+
+        def get_replacement_qwen2vl(item_idx: int, modality: str):
+            grid_thw = hf_inputs[f"{modality}_grid_thw"][item_idx]
+            num_tokens = grid_thw.prod() // merge_length
+            return placeholder[modality] * num_tokens
+
+        return [
+            PromptReplacement(
+                modality=modality,
+                target=placeholder[modality],
+                replacement=partial(get_replacement_qwen2vl,
+                                    modality=modality),
+            ) for modality in ("image", "video")
+        ]
+
+    def _get_dummy_mm_inputs(
+        self,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        num_images = mm_counts["image"]
+        hf_processor = self._get_hf_processor()
+        image_token: str = hf_processor.image_token
+        image_processor = _get_image_processor(hf_processor)
+
+        data = {}
+        resized_height, resized_width = smart_resize(
+            height=9999999,
+            width=9999999,
+            factor=image_processor.patch_size * image_processor.merge_size,
+            min_pixels=image_processor.min_pixels,
+            max_pixels=image_processor.max_pixels,
+        )
+
+        dummy_image = Image.new("RGB", (resized_width, resized_height),
+                                color=0)
+
+        data["image"] = [dummy_image] * num_images
+
+        return ProcessorInputs(
+            prompt_text=image_token * num_images,
+            mm_data=data,
+            mm_processor_kwargs={},
+        )
+
+
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens)
 @MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
     "video", get_max_qwen2_vl_video_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl)
+@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor)
 class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
@@ -1110,7 +904,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
-        config = vllm_config.model_config.hf_config
+        config: Qwen2VLConfig = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         multimodal_config = vllm_config.model_config.multimodal_config

View File

@@ -220,15 +220,18 @@ class MultiModalDataItems(UserDict[str, list[Any]]):
         multi_data = MultiModalDataItems()
 
         for k, v in data.items():
+            # TODO: Make a separate modality for embedding inputs
+            # to avoid confusion
             # yapf: disable
             if k == "video":
                 # Special case since even a single item can be a list
                 multi_data[k] = (  # type: ignore[index]
-                    v if is_list_of(v, (list, torch.Tensor)) else [v]
+                    v if (isinstance(v, torch.Tensor)
+                          or is_list_of(v, list)) else [v]
                 )
             elif k in ("image", "audio"):
                 multi_data[k] = (  # type: ignore[index]
-                    v if isinstance(v, (list, torch.Tensor)) else [v]
+                    v if isinstance(v, (torch.Tensor, list)) else [v]
                 )
             else:
                 multi_data[k] = v if isinstance(v, list) else [v]  # type: ignore[index]
@@ -252,6 +255,9 @@ class MultiModalDataItems(UserDict[str, list[Any]]):
     def audios(self) -> Sequence[AudioItem]:
         return self.get("audio", [])
 
+    def get_item_counts(self) -> Mapping[str, int]:
+        return {m: len(items) for m, items in self.items()}
+
     def get_image_size(self, item_idx: int) -> ImageSize:
         image = self.images[item_idx]
@@ -612,6 +618,12 @@ class BaseMultiModalProcessor(ABC):
     def _get_tokenizer(self) -> AnyTokenizer:
         return self.ctx.tokenizer
 
+    def _get_mm_items(
+        self,
+        mm_data: MultiModalDataDict,
+    ) -> MultiModalDataItems:
+        return MultiModalDataItems.from_dict(mm_data)
+
     @abstractmethod
     def _get_prompt_replacements(
         self,
@@ -778,7 +790,7 @@ class BaseMultiModalProcessor(ABC):
         3. Extract information about the placeholder tokens from the
            processed token IDs.
         """
-        mm_items = MultiModalDataItems.from_dict(mm_data)
+        mm_items = self._get_mm_items(mm_data)
 
         hf_inputs = self._apply_hf_processor(prompt_text, mm_items,
                                              mm_processor_kwargs)
@@ -791,7 +803,7 @@ class BaseMultiModalProcessor(ABC):
         # If HF processor already inserts placeholder tokens,
         # there is no need for us to insert them
-        mm_item_counts = {m: len(items) for m, items in mm_items.items()}
+        mm_item_counts = mm_items.get_item_counts()
         all_placeholders = self._find_placeholders(all_prompt_repls,
                                                    prompt_ids, mm_item_counts)