[Bugfix] Make image processor respect mm_processor_kwargs for Qwen2-VL (#10112)

Signed-off-by: Jiahao Li <liplus17@163.com>
Jiahao Li 2024-11-07 18:50:44 +08:00 committed by GitHub
parent a6f332d0d9
commit 999df95b4e

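For context, mm_processor_kwargs are the per-model multimodal processor overrides that users supply when constructing the engine; for Qwen2-VL they are the min_pixels / max_pixels bounds on image resizing. A minimal usage sketch follows; the checkpoint name and pixel values are illustrative, not taken from this commit:

```python
from vllm import LLM

# Illustrative overrides: min_pixels / max_pixels bound the image area that
# the Qwen2-VL image processor resizes images into before encoding.
llm = LLM(
    model="Qwen/Qwen2-VL-7B-Instruct",
    mm_processor_kwargs={
        "min_pixels": 28 * 28 * 64,
        "max_pixels": 28 * 28 * 1280,
    },
)
```

Before this change, these overrides reached mm_input_mapper_for_qwen2_vl but not get_max_qwen2_vl_mm_tokens or dummy_data_for_qwen2_vl, so the multimodal token budget and the dummy profiling data were computed with the checkpoint defaults.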

@@ -22,8 +22,8 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 from functools import partial
-from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
-                    Tuple, Type, TypedDict, Union)
+from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
+                    Optional, Tuple, Type, TypedDict, Union)
 
 import torch
 import torch.nn as nn
@@ -558,6 +558,17 @@ class Qwen2VisionTransformer(nn.Module):
 # === Vision input helpers === #
 
 
+def get_mm_processor_kwargs(
+        min_pixels: Optional[int] = None,
+        max_pixels: Optional[int] = None) -> Dict[str, int]:
+    mm_processor_kwargs = {}
+    if min_pixels:
+        mm_processor_kwargs["min_pixels"] = min_pixels
+    if max_pixels:
+        mm_processor_kwargs["max_pixels"] = max_pixels
+    return mm_processor_kwargs
+
+
 def mm_input_mapper_for_qwen2_vl(
     ctx: InputContext,
     data: MultiModalData[object],
@@ -575,12 +586,8 @@ def mm_input_mapper_for_qwen2_vl(
     model_config = ctx.model_config
     # Handle mm processor kwargs; we pass these at creation time
     # because preprocess() in transformers doesn't expose them
-    mm_processor_kwargs = {}
-    if min_pixels:
-        mm_processor_kwargs["min_pixels"] = min_pixels
-    if max_pixels:
-        mm_processor_kwargs["max_pixels"] = max_pixels
+    mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels,
+                                                  max_pixels=max_pixels)
     image_processor = cached_get_image_processor(
         model_config.model,
         trust_remote_code=model_config.trust_remote_code,
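The comment in the hunk above states the constraint this change works around: the kwargs cannot be forwarded per call, so they are applied when the image processor is instantiated. A rough sketch of that creation-time pattern against the HuggingFace API, assuming a transformers version with Qwen2-VL support and an illustrative checkpoint name:

```python
from transformers import AutoImageProcessor

# Sketch only: the pixel bounds are baked in at construction time, so every
# later call on this processor instance uses the overridden values.
image_processor = AutoImageProcessor.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    min_pixels=28 * 28 * 64,
    max_pixels=28 * 28 * 1280,
)
```

cached_get_image_processor in the diff applies the same idea, with caching on top so repeated lookups with identical kwargs can reuse one instance.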
@@ -683,7 +690,10 @@ def get_max_qwen2_vl_mm_tokens(ctx: InputContext,
                                *,
                                min_pixels=None,
                                max_pixels=None) -> int:
-    image_processor = cached_get_image_processor(ctx.model_config.model)
+    mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels,
+                                                  max_pixels=max_pixels)
+    image_processor = cached_get_image_processor(ctx.model_config.model,
+                                                 **mm_processor_kwargs)
     max_resized_height, max_resized_width, max_llm_image_tokens = \
         _get_max_image_info(image_processor, data_type_key=data_type_key,
                             mm_count=1, min_pixels=min_pixels,
@@ -705,7 +715,10 @@ def dummy_data_for_qwen2_vl(
     min_pixels: Optional[int] = None,
     max_pixels: Optional[int] = None
 ) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
-    image_processor = cached_get_image_processor(ctx.model_config.model)
+    mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels,
+                                                  max_pixels=max_pixels)
+    image_processor = cached_get_image_processor(ctx.model_config.model,
+                                                 **mm_processor_kwargs)
     num_images = mm_counts["image"]
     max_resized_height, max_resized_width, max_llm_image_tokens = \
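A closing note on the new helper: get_mm_processor_kwargs filters on truthiness, so an unset (None) bound is simply omitted and the processor keeps the checkpoint default; an explicit 0 would be dropped for the same reason. Expected behaviour, with the import path assumed from where this diff lives:

```python
# Assumed import path for the helper added in this commit.
from vllm.model_executor.models.qwen2_vl import get_mm_processor_kwargs

print(get_mm_processor_kwargs())                    # {}
print(get_mm_processor_kwargs(min_pixels=50176))    # {'min_pixels': 50176}
print(get_mm_processor_kwargs(min_pixels=50176,
                              max_pixels=1003520))  # {'min_pixels': 50176, 'max_pixels': 1003520}
```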