[Model] merged input processor for Phi-3-Vision models (#10977)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Author: Isotr0py
Date: 2024-12-10 04:55:10 +08:00 (committed by GitHub)
Commit: a811dd6608
Parent: ca871491ed
7 changed files with 234 additions and 408 deletions
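
For orientation, a minimal sketch (not part of this commit) of how the merged processor is driven, following the updated phi3v processing test in this diff. It assumes the `build_model_context` helper from vLLM's test utilities (imported via a relative path in that test) is in scope, and uses a blank placeholder image.

# Sketch only: mirrors the updated test in this diff; build_model_context is
# the vLLM test helper, assumed available here.
from PIL import Image
from transformers import AutoTokenizer

from vllm.inputs import InputProcessingContext
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID, Phi3VProcessor

model_id = "microsoft/Phi-3.5-vision-instruct"
ctx = build_model_context(model_name=model_id,
                          tokenizer_name=model_id,
                          trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)

# A single merged processor now handles prompt expansion and image preprocessing.
processor = Phi3VProcessor(ctx)
prompt = "<|user|>\n<|image_1|>\n<|end|>\n<|assistant|>\n"
mm_data = {"image": [Image.new("RGB", (1344, 1344))]}  # placeholder image
processed = processor.apply(prompt, mm_data, {"num_crops": 16})

# The image tag is replaced by the expected number of image placeholder tokens.
print(processed["prompt_token_ids"].count(_IMAGE_TOKEN_ID))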

View File

@@ -89,7 +89,7 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=772, total_tokens=782)
+        completion_tokens=10, prompt_tokens=775, total_tokens=785)

     message = choice.message
     message = chat_completion.choices[0].message
@@ -181,7 +181,7 @@ async def test_single_chat_session_image_base64encoded(
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=772, total_tokens=782)
+        completion_tokens=10, prompt_tokens=775, total_tokens=785)

     message = choice.message
     message = chat_completion.choices[0].message

View File

@@ -95,5 +95,5 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
     assert len(embeddings["data"]) == 1
     assert len(embeddings["data"][0]["embedding"]) == 3072
     assert embeddings["usage"]["completion_tokens"] == 0
-    assert embeddings["usage"]["prompt_tokens"] == 762
-    assert embeddings["usage"]["total_tokens"] == 762
+    assert embeddings["usage"]["prompt_tokens"] == 765
+    assert embeddings["usage"]["total_tokens"] == 765

View File

@@ -2,12 +2,10 @@
 from typing import Optional

 import pytest
-import torch
-from transformers import AutoImageProcessor, AutoTokenizer
+from transformers import AutoTokenizer

-from vllm.inputs import InputContext, token_inputs
+from vllm.inputs import InputContext, InputProcessingContext
 from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
-from vllm.multimodal import MultiModalRegistry

 from .....conftest import _ImageAssets
 from ....utils import build_model_context
@@ -17,15 +15,9 @@ models = ["microsoft/Phi-3.5-vision-instruct"]

 # Wrap lazy imports to avoid initializing CUDA during test collection
 @pytest.fixture()
-def input_processor_for_phi3v():
-    from vllm.model_executor.models.phi3v import input_processor_for_phi3v
-    return input_processor_for_phi3v
-
-
-@pytest.fixture()
-def dummy_data_for_phi3v():
-    from vllm.model_executor.models.phi3v import dummy_data_for_phi3v
-    return dummy_data_for_phi3v
+def processor_for_phi3v():
+    from vllm.model_executor.models.phi3v import Phi3VProcessor
+    return Phi3VProcessor


 @pytest.fixture()
@@ -34,53 +26,6 @@ def get_max_phi3v_image_tokens():
     return get_max_phi3v_image_tokens


-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("num_crops", [4, 16, None])
-def test_input_mapper_override(model: str, image_assets: _ImageAssets,
-                               num_crops: Optional[int]):
-    """Ensure that the [default] input mapper handles num_crops properly."""
-    # We pass the processor kwargs here since for this model, we fall back to
-    # the default mapper; this will fall back to the HF mapper and forward
-    # mm_processor_kwargs to it.
-    mm_processor_kwargs = {
-        "num_crops": num_crops
-    } if num_crops is not None else {}
-    ctx = build_model_context(
-        model_name=model,
-        tokenizer_name=model,
-        trust_remote_code=True,
-        mm_processor_kwargs=mm_processor_kwargs,
-    )
-
-    hf_processor = AutoImageProcessor.from_pretrained(model,
-                                                      trust_remote_code=True,
-                                                      **mm_processor_kwargs)
-
-    mm_registry = MultiModalRegistry()
-    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-
-    image = image_assets[0].pil_image
-    hf_result = hf_processor.preprocess(
-        image,
-        return_tensors="pt",
-    )
-
-    vllm_result = mm_registry.map_input(
-        ctx.model_config,
-        {"image": image},
-    )
-
-    assert torch.all(hf_result["image_sizes"] == vllm_result["image_sizes"])
-    assert torch.all(
-        hf_result["num_img_tokens"] == vllm_result["num_img_tokens"])
-
-    # For pixel values, the second axis should be the num_crops + 1
-    # for the rescaled original image. The default value in VLLM falls
-    # back to the HF config, which is why we compare to the processor num_crops
-    assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"])
-    assert vllm_result["pixel_values"].shape[1] == hf_processor.num_crops + 1
-
-
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize("num_crops,expected_max_tokens", [
     (4, 781),
@@ -112,48 +57,20 @@ def test_max_tokens_override(get_max_phi3v_image_tokens, model: str,


 @pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("num_crops,toks_per_img,num_imgs", [
-    (4, 781, 1),
-    (4, 781, 2),
-    (16, 2653, 1),
-    (16, 2653, 2),
-])
-def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int,
-                             toks_per_img: int, num_imgs: int):
-    """Ensure dummy_data_for_phi3v handles num_crops properly."""
-    # Same as the previous test - don't initialize mm_processor_kwargs
-    # in this test and assume that the kwargs will be correctly expanded by
-    # the partial when calling the dummy data func.
-    ctx = build_model_context(
-        model_name=model,
-        tokenizer_name=model,
-        trust_remote_code=True,
-        mm_processor_kwargs=None,
-    )
-
-    dummy_data = dummy_data_for_phi3v(
-        ctx=ctx,
-        seq_len=8192,  # Should be bigger than num_imgs * toks_per_img
-        mm_counts={"image": num_imgs},
-        num_crops=num_crops,
-    )
-    sequence_data = dummy_data.seq_data
-    # Ensure we have the right number of placeholders per num_crops size
-    img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID)
-    assert img_tok_count == toks_per_img * num_imgs
-
-
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("num_crops,expected_toks_per_img,num_imgs", [
-    (4, 757, 1),
-    (4, 757, 2),
-    (16, 1921, 1),
-    (16, 1921, 2),
-])
-def test_input_processor_override(input_processor_for_phi3v,
-                                  image_assets: _ImageAssets, model: str,
-                                  num_crops: int, expected_toks_per_img: int,
-                                  num_imgs: int):
+@pytest.mark.parametrize(
+    "num_crops,expected_toks_per_img,num_imgs",
+    [
+        (4, 757, 1),
+        (4, 757, 2),
+        (16, 1921, 1),
+        (16, 1921, 2),
+        # the default num_crops of phi-3.5-vision is 4
+        (None, 757, 2),
+        (None, 757, 2),
+    ])
+def test_processor_override(processor_for_phi3v, image_assets: _ImageAssets,
+                            model: str, num_crops: Optional[int],
+                            expected_toks_per_img: int, num_imgs: int):
     """Ensure input_processor_for_phi3v handles num_crops properly."""
     # Same as the previous test - don't initialize mm_processor_kwargs
     # in this test and assume that the kwargs will be correctly expanded by
@@ -163,19 +80,20 @@ def test_input_processor_override(input_processor_for_phi3v,
         tokenizer_name=model,
         trust_remote_code=True,
     )
-    tokenizer = AutoTokenizer.from_pretrained(model)
+    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+    ctx = InputProcessingContext(ctx.model_config, tokenizer)
     # Build the image str / prompt based on the number of images we pass
     img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
     prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
     images = [image_assets[0].pil_image] * num_imgs
+    mm_data = {"image": images}
+    mm_processor_kwargs = {}
+    if num_crops is not None:
+        mm_processor_kwargs = {"num_crops": num_crops}

-    inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
-                          prompt=prompt,
-                          multi_modal_data={"image": images})
-
-    processed_inputs = input_processor_for_phi3v(ctx,
-                                                 inputs,
-                                                 num_crops=num_crops)
+    processor = processor_for_phi3v(ctx)
+    processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)

     # Ensure we have the right number of placeholders per num_crops size
     img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)

View File

@@ -15,13 +15,13 @@ from ..models.utils import build_model_context
 # Used for fast tests where the model doesn't matter
 DUMMY_MODEL_ID = "facebook/opt-125m"
 # Used for tests that need a multimodal model
-MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
+MULTIMODAL_MODEL_ID = "OpenGVLab/InternVL2-2B"

 # For mm_processor_kwargs - we test overrides by defining mocks for each place
 # it is used, and ensuring that we can pass processor kwargs an override value
 # to receive the intended result for things like sequence length etc.
-DEFAULT_NUM_CROPS = 4
-NUM_CROPS_OVERRIDE = 16
+DEFAULT_MAX_DYNAMIC_PATCH = 6
+MAX_DYNAMIC_PATCH_OVERRIDE = 4

 # Mocks for all of the places that we use the mm_processor_kwargs
@@ -33,10 +33,11 @@ def use_processor_mock():
     def custom_processor(ctx: InputContext,
                          inputs: DecoderOnlyInputs,
                          *,
-                         num_crops=DEFAULT_NUM_CROPS):
+                         max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH):
         # For testing purposes, we don't worry about the prompt
-        return token_inputs(prompt_token_ids=[],
-                            mm_processor_kwargs={"num_crops": num_crops})
+        return token_inputs(
+            prompt_token_ids=[],
+            mm_processor_kwargs={"max_dynamic_patch": max_dynamic_patch})

     with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor",
                return_value=custom_processor):
@@ -52,9 +53,9 @@ def use_dummy_data_mock():
                                   seq_len: int,
                                   mm_counts: Mapping[str, int],
                                   *,
-                                  num_crops=DEFAULT_NUM_CROPS):
+                                  max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH):
         seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops))
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * max_dynamic_patch))
         return DummyData(seq_data, None)

     with patch(
@@ -65,15 +66,15 @@ def use_dummy_data_mock():
 # Lazy import to avoid CUDA reinitialization error
 def mm_model_cls():
-    from vllm.model_executor.models.phi3v import Phi3VForCausalLM
+    from vllm.model_executor.models.internvl import InternVLChatModel

-    return Phi3VForCausalLM
+    return InternVLChatModel


 # lambda whose signature matches max token calcs extra & mapper + extra kwargs
-get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops
-custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: {
-    "pixel_values": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
+get_max_dynamic_patch = lambda ctx, *, max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH: max_dynamic_patch  # noqa: E501
+custom_mapper = lambda ctx, data, *, max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH: {  # noqa: E501
+    "pixel_values": torch.zeros(size=(1, max_dynamic_patch + 1, 3, 448, 448))
 }
@@ -88,27 +89,28 @@ def test_default_processor_is_a_noop():
     assert proc_inputs is proc_outputs


-def _get_num_crops_info(init_num_crops: int, inference_num_crops: int):
-    """Get the init / inference kwargs and expected num_crops for this test."""
-    # If we have a value for num_crops, pass the override value and make
+def _get_max_dynamic_patch_info(init_max_dynamic_patch: int,
+                                inference_max_dynamic_patch: int):
+    """Get the init / inference kwargs and expected max_dynamic_patch."""
+    # If we have a value for max_dynamic_patch, pass the override value and make
     # sure we get that value as a return-value from out mock processor,
     # otherwise fall back to the default value
-    init_kwargs = None if init_num_crops is None else {
-        "num_crops": init_num_crops
+    init_kwargs = None if init_max_dynamic_patch is None else {
+        "max_dynamic_patch": init_max_dynamic_patch
     }
-    inference_kwargs = None if inference_num_crops is None else {
-        "num_crops": inference_num_crops
+    inference_kwargs = None if inference_max_dynamic_patch is None else {
+        "max_dynamic_patch": inference_max_dynamic_patch
     }
-    if inference_num_crops is not None:
-        expected_seq_count = inference_num_crops
-    elif init_num_crops is not None:
-        expected_seq_count = init_num_crops
+    if inference_max_dynamic_patch is not None:
+        expected_seq_count = inference_max_dynamic_patch
+    elif init_max_dynamic_patch is not None:
+        expected_seq_count = init_max_dynamic_patch
     else:
-        expected_seq_count = DEFAULT_NUM_CROPS
+        expected_seq_count = DEFAULT_MAX_DYNAMIC_PATCH

     return init_kwargs, inference_kwargs, expected_seq_count


-def _get_processed_num_crops(
+def _get_processed_max_dynamic_patch(
     processor: Callable[[ProcessorInputs], ProcessorInputs],
     inference_kwargs: Optional[Dict[str, int]],
 ) -> int:
@@ -120,27 +122,30 @@ def _get_processed_num_crops(
     assert "type" in processed_inputs
     assert processed_inputs["type"] == "token"
     assert "mm_processor_kwargs" in processed_inputs
-    return processed_inputs["mm_processor_kwargs"]["num_crops"]
+    return processed_inputs["mm_processor_kwargs"]["max_dynamic_patch"]


-@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
-    (None, None),
-    (NUM_CROPS_OVERRIDE, None),
-    (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
-])
-def test_input_processor_kwargs(use_processor_mock, init_num_crops,
-                                inference_num_crops):
+@pytest.mark.parametrize(
+    "init_max_dynamic_patch,inference_max_dynamic_patch", [
+        (None, None),
+        (MAX_DYNAMIC_PATCH_OVERRIDE, None),
+        (DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE),
+    ])
+def test_input_processor_kwargs(use_processor_mock, init_max_dynamic_patch,
+                                inference_max_dynamic_patch):
     """Ensure input processors can use processor kwargs."""
     dummy_registry = InputRegistry()

-    init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
-        init_num_crops, inference_num_crops)
+    (init_kwargs, inference_kwargs,
+     expected_seq_count) = _get_max_dynamic_patch_info(
+         init_max_dynamic_patch, inference_max_dynamic_patch)

     ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
     processor = dummy_registry.create_input_processor(ctx.model_config)
-    num_crops_val = _get_processed_num_crops(processor, inference_kwargs)
+    max_dynamic_patch_val = _get_processed_max_dynamic_patch(
+        processor, inference_kwargs)

-    assert num_crops_val == expected_seq_count
+    assert max_dynamic_patch_val == expected_seq_count


 @pytest.mark.parametrize(
@@ -165,18 +170,21 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
     processor = dummy_registry.create_input_processor(ctx.model_config)
     # Should filter out the inference time kwargs
-    num_crops_val = _get_processed_num_crops(processor, mm_processor_kwargs)
-    assert num_crops_val == DEFAULT_NUM_CROPS
+    max_dynamic_patch_val = _get_processed_max_dynamic_patch(
+        processor, mm_processor_kwargs)
+    assert max_dynamic_patch_val == DEFAULT_MAX_DYNAMIC_PATCH


 ### Test overrides for the dummy data
-@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
-def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
+@pytest.mark.parametrize("max_dynamic_patch",
+                         [None, MAX_DYNAMIC_PATCH_OVERRIDE])
+def test_dummy_data_kwarg_overrides(use_dummy_data_mock, max_dynamic_patch):
     """Ensure dummy data factories can use processor kwargs."""
-    mm_processor_kwargs = None if num_crops is None else {
-        "num_crops": num_crops
+    mm_processor_kwargs = None if max_dynamic_patch is None else {
+        "max_dynamic_patch": max_dynamic_patch
     }
-    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
+    expected_seq_count = (DEFAULT_MAX_DYNAMIC_PATCH
+                          if max_dynamic_patch is None else max_dynamic_patch)
     dummy_registry = InputRegistry()
     ctx = build_model_context(DUMMY_MODEL_ID,
                               mm_processor_kwargs=mm_processor_kwargs)
@@ -217,17 +225,20 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
     # len is solely dependent on the value of the mm_processor_kwargs.
     dummy_data = dummy_registry.dummy_data_for_profiling(
         ctx.model_config, seq_len=-1, mm_registry=mm_registry)
-    assert len(dummy_data.seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
+    assert len(
+        dummy_data.seq_data.prompt_token_ids) == DEFAULT_MAX_DYNAMIC_PATCH


 ### Test overrides for the max token count per multimodal instance
-@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
-def test_max_tokens_kwarg_overrides(num_crops):
+@pytest.mark.parametrize("max_dynamic_patch",
+                         [None, MAX_DYNAMIC_PATCH_OVERRIDE])
+def test_max_tokens_kwarg_overrides(max_dynamic_patch):
     """Ensure max token calcs can use processor kwargs."""
-    mm_processor_kwargs = None if num_crops is None else {
-        "num_crops": num_crops
+    mm_processor_kwargs = None if max_dynamic_patch is None else {
+        "max_dynamic_patch": max_dynamic_patch
     }
-    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
+    expected_seq_count = (DEFAULT_MAX_DYNAMIC_PATCH
+                          if max_dynamic_patch is None else max_dynamic_patch)

     ctx = build_model_context(MULTIMODAL_MODEL_ID,
                               task="generate",
@@ -239,11 +250,11 @@ def test_max_tokens_kwarg_overrides(num_crops):
     mm_registry.init_mm_limits_per_prompt(ctx.model_config)
     # Patch the image registry for phi3v with our lambda that is compatible
     # with overrides, then ensure that calling the method correctly echos
-    # our num_crops value back from the mm_processor_kwargs.
+    # our max_dynamic_patch value back from the mm_processor_kwargs.
     with patch.object(
             mm_registry._get_plugin("image"),
             "_max_mm_tokens",
-            {mm_model_cls(): get_num_crops},
+            {mm_model_cls(): get_max_dynamic_patch},
     ):
         max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
             ctx.model_config)
@@ -279,26 +290,29 @@ def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
     with patch.object(
             mm_registry._get_plugin("image"),
             "_max_mm_tokens",
-            {mm_model_cls(): get_num_crops},
+            {mm_model_cls(): get_max_dynamic_patch},
     ):
         max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
             ctx.model_config)

-        assert max_multimodal_tokens == DEFAULT_NUM_CROPS
+        assert max_multimodal_tokens == DEFAULT_MAX_DYNAMIC_PATCH


 ### Test overrides for the mapper
-@pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE])
-def test_default_mapper_with_processor_kwargs(image_assets, num_crops):
+@pytest.mark.parametrize(
+    "max_dynamic_patch",
+    [DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE])
+def test_default_mapper_with_processor_kwargs(image_assets, max_dynamic_patch):
     """Ensure that the mapper processor kwargs can fall back to HF models."""
     # NOTE - we don't validate bad inputs for the default mapper, because it's
     # through the automodel interface in transformers, so we can't easily
     # inspect what kwargs are or are not allowed.
-    ctx = build_model_context(MULTIMODAL_MODEL_ID,
-                              task="generate",
-                              trust_remote_code=True,
-                              mm_processor_kwargs={"num_crops": num_crops},
-                              limit_mm_per_prompt={"image": 1})
+    ctx = build_model_context(
+        MULTIMODAL_MODEL_ID,
+        task="generate",
+        trust_remote_code=True,
+        mm_processor_kwargs={"max_dynamic_patch": max_dynamic_patch},
+        limit_mm_per_prompt={"image": 1})

     mm_registry = MultiModalRegistry()
     mm_registry.init_mm_limits_per_prompt(ctx.model_config)
@@ -307,20 +321,22 @@ def test_default_mapper_with_processor_kwargs(image_assets, num_crops):
     mm_inputs = {"image": image}
     mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)

-    # Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336]
-    assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1
+    # pixel vals should have shape: [batch, max_dynamic_patch+1, ...]
+    assert mapped_inputs["pixel_values"].shape[1] == max_dynamic_patch + 1


-@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
-    (None, None),
-    (NUM_CROPS_OVERRIDE, None),
-    (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
-])
-def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops,
-                                       inference_num_crops):
+@pytest.mark.parametrize(
+    "init_max_dynamic_patch,inference_max_dynamic_patch", [
+        (None, None),
+        (MAX_DYNAMIC_PATCH_OVERRIDE, None),
+        (DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE),
+    ])
+def test_custom_mapper_kwarg_overrides(image_assets, init_max_dynamic_patch,
+                                       inference_max_dynamic_patch):
     """Ensure custom mappers can use processor kwargs."""
-    init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
-        init_num_crops, inference_num_crops)
+    (init_kwargs, inference_kwargs,
+     expected_seq_count) = _get_max_dynamic_patch_info(
+         init_max_dynamic_patch, inference_max_dynamic_patch)

     ctx = build_model_context(MULTIMODAL_MODEL_ID,
                               task="generate",
@@ -335,7 +351,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops,
     # Patch the image registry for phi3v with our lambda that is compatible
     # with overrides, then ensure that calling the method correctly echos
-    # our num_crops value back from the mm_processor_kwargs.
+    # our max_dynamic_patch value back from the mm_processor_kwargs.
     mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
         mm_model_cls())
     mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs,
@@ -373,11 +389,12 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
     # Patch the image registry for phi3v with our lambda that is compatible
     # with overrides, then ensure that calling the method correctly echos
-    # our num_crops value back from the mm_processor_kwargs.
+    # our max_dynamic_patch value back from the mm_processor_kwargs.
     mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
         mm_model_cls())
     # Should filter out the inference time kwargs
     mapped_inputs = mm_registry.map_input(
         ctx.model_config, mm_inputs, mm_processor_kwargs=mm_processor_kwargs)
-    assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1
+    assert mapped_inputs["pixel_values"].shape[1] == (
+        DEFAULT_MAX_DYNAMIC_PATCH + 1)

View File

@@ -69,12 +69,12 @@ class InputProcessingContext(InputContext):
     tokenizer: AnyTokenizer
     """The tokenizer used to tokenize the inputs."""

-    def get_hf_processor(self) -> ProcessorMixin:
+    def get_hf_processor(self, **kwargs) -> ProcessorMixin:
         return cached_get_processor(
             self.model_config.tokenizer,
             tokenizer=self.tokenizer,  # Override the tokenizer with ours
             trust_remote_code=self.model_config.trust_remote_code,
-        )
+            **kwargs)


 N = TypeVar("N", bound=Type[nn.Module])

View File

@@ -12,22 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import itertools
-import re
-from functools import cached_property, lru_cache
-from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
-                    Tuple, TypedDict, Union)
+from functools import cached_property
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
+                    TypedDict, Union)

-import numpy as np
 import torch
 import torch.nn as nn
-from PIL import Image
-from transformers import CLIPVisionConfig, PretrainedConfig
+from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig,
+                          ProcessorMixin)

 from vllm.attention import AttentionMetadata
-from vllm.config import ModelConfig, VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
+from vllm.config import VllmConfig
+from vllm.inputs import InputContext
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -36,12 +32,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
-from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
+from vllm.multimodal.image import cached_get_image_processor
+from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        InputProcessingContext,
+                                        ModalityProcessingMetadata,
+                                        MultiModalDataDict,
+                                        MultiModalProcessingMetadata,
+                                        PromptReplacement)
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of

-from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
+from .clip import dummy_image_for_clip
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
@@ -303,231 +305,99 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
         return image_features_hd_newline


-# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
-def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
-    target_height = int(np.ceil(height / padding_unit) * padding_unit)
-    top_padding = int((target_height - height) / 2)
-    bottom_padding = target_height - height - top_padding
-    padded_width = width
-    padded_height = height + top_padding + bottom_padding
-    return padded_width, padded_height
-
-
-# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
-def _calc_hd_transform_size(*, width: int, height: int, hd_num: int):
-    transposed = False
-    if width < height:
-        width, height = height, width
-        transposed = True
-
-    ratio = width / height
-    scale = 1
-    while scale * np.ceil(scale / ratio) <= hd_num:
-        scale += 1
-    scale -= 1
-
-    new_width = int(scale * 336)
-    new_height = int(new_width / ratio)
-
-    padded_width, padded_height = _calc_padded_size(width=new_width,
-                                                    height=new_height)
-
-    if transposed:
-        padded_width, padded_height = padded_height, padded_width
-
-    return padded_width, padded_height
-
-
-# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
-def get_phi3v_image_feature_size(
-    hf_config: Dict[str, Any],
-    *,
-    input_height: int,
-    input_width: int,
-    num_crops: int,
-) -> int:
-    if num_crops is None:
-        num_crops = hf_config.get("num_crops", 16)
-    new_width, new_height = _calc_hd_transform_size(width=input_width,
-                                                    height=input_height,
-                                                    hd_num=num_crops)
-
-    return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
-        + (new_height // 336 + 1) * 12
-
-
 def get_max_phi3v_image_tokens(ctx: InputContext,
                                *,
                                num_crops: Optional[int] = None):
-
-    return get_phi3v_image_feature_size(
-        ctx.get_hf_image_processor_config(),
-        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
-        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
-        num_crops=num_crops,
-    )
-
-
-def dummy_data_for_phi3v(ctx: InputContext,
-                         seq_len: int,
-                         mm_counts: Mapping[str, int],
-                         *,
-                         num_crops: Optional[int] = None):
-    num_images = mm_counts["image"]
-
-    image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops)
-
-    seq_data, ranges = dummy_seq_data_for_clip(
-        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
-        seq_len,
-        num_images,
-        image_token_id=_IMAGE_TOKEN_ID,
-        image_feature_size_override=image_feature_size,
-    )
-    mm_data = dummy_image_for_clip(
+    mm_processor_kwargs = {}
+    if num_crops is not None:
+        mm_processor_kwargs["num_crops"] = num_crops
+
+    model_config = ctx.model_config
+    image_processor = cached_get_image_processor(
+        model_config.model,
+        trust_remote_code=model_config.trust_remote_code,
+        **mm_processor_kwargs,
+    )
+
+    num_tokens = image_processor.calc_num_image_tokens_from_image_size(
+        width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
+        height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+    )
+    return num_tokens
+
+
+def dummy_mm_kwargs_for_phi3v(ctx: InputProcessingContext,
+                              mm_counts: Mapping[str, int]):
+    num_images = mm_counts["image"]
+
+    data = dummy_image_for_clip(
         CLIP_VIT_LARGE_PATCH14_336_CONFIG,
         num_images,
         image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
         image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
     )

-    return DummyData(seq_data, mm_data, ranges)
-
-
-@lru_cache
-def _get_image_placeholder_token_id_candidates(
-    model_config: ModelConfig,
-    idx: int,
-) -> List[List[int]]:
-    assert idx > 0
-
-    tokenizer = cached_get_tokenizer(model_config.tokenizer)
-
-    # This is used when the image token is at the start of the string
-    start_candidate = tokenizer.encode(f"<|image_{idx}|>",
-                                       add_special_tokens=False)
-
-    # This is used when the image token is in the middle of the string
-    # We need to get the token for "<", not "▁<"
-    # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
-    a_token_id, = tokenizer.encode("a", add_special_tokens=False)
-    a_token_id_, *middle_candidate = tokenizer.encode(f"a<|image_{idx}|>",
-                                                      add_special_tokens=False)
-    assert a_token_id == a_token_id_
-
-    return [start_candidate, middle_candidate]
-
-
-def input_processor_for_phi3v(ctx: InputContext,
-                              inputs: DecoderOnlyInputs,
-                              *,
-                              num_crops: Optional[int] = None):
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        return inputs
-
-    model_config = ctx.model_config
-    hf_config = ctx.get_hf_image_processor_config()
-
-    image_data = multi_modal_data["image"]
-    if isinstance(image_data, Image.Image):
-        w, h = image_data.size
-        image_feature_size = [
-            get_phi3v_image_feature_size(hf_config,
-                                         input_width=w,
-                                         input_height=h,
-                                         num_crops=num_crops)
-        ]
-        image_data = [image_data]
-    elif is_list_of(image_data, Image.Image):
-        image_feature_size = []
-        for image in image_data:
-            w, h = image.size
-            image_feature_size.append(
-                get_phi3v_image_feature_size(hf_config,
-                                             input_width=w,
-                                             input_height=h,
-                                             num_crops=num_crops))
-    elif isinstance(image_data, torch.Tensor):
-        image_feature_size = [image_data.shape[0]]
-        image_data = [image_data]
-    elif is_list_of(image_data, torch.Tensor):
-        image_feature_size = [item.shape[0] for item in image_data]
-    else:
-        raise TypeError(f"Invalid image type: {type(image_data)}")
-
-    prompt = inputs.get("prompt")
-    if prompt is None:
-        # for async server request, we assume prompt and its token_ids is always
-        # in correct format. And num_image_tags == len(image_data) always True.
-        image_idx = range(1, len(image_data) + 1)
-        new_prompt = None
-    else:
-        image_idx = sorted(map(int, re.findall(r"<\|image_(\d+)\|>+", prompt)))
-        if prompt.count("<|image|>") > 0:
-            logger.warning("Please follow the prompt format that is "
-                           "documented on HuggingFace which does not involve "
-                           "repeating <|image|> tokens.")
-        elif (num_image_tags := len(image_idx)) > 1:
-            assert num_image_tags == len(
-                image_data), "The count of image_placeholder not match image's"
-        new_prompt = prompt
-
-    prompt_token_ids = inputs["prompt_token_ids"].copy()
-
-    # masked placeholder with image token id
-    for idx in image_idx:
-        candidates = _get_image_placeholder_token_id_candidates(model_config,
-                                                                idx=idx)
-
-        for candidate in candidates:
-            for i in range(len(prompt_token_ids) - len(candidate) + 1):
-                if prompt_token_ids[i:i + len(candidate)] == candidate:
-                    prompt_token_ids[i:i +
-                                     len(candidate)] = ([_IMAGE_TOKEN_ID] *
-                                                        len(candidate))
-                    break
-
-    # merge consecutive tag ids
-    merged_token_ids: List[int] = []
-    for is_placeholder, token_ids in itertools.groupby(
-            prompt_token_ids, lambda x: x == _IMAGE_TOKEN_ID):
-        if is_placeholder:
-            merged_token_ids.append(_IMAGE_TOKEN_ID)
-        else:
-            merged_token_ids.extend(list(token_ids))
-
-    # TODO: Move this to utils or integrate with clip.
-    new_token_ids: List[int] = []
-    placeholder_ranges: List[PlaceholderRange] = []
-    placeholder_idx = 0
-    while merged_token_ids:
-        token_id = merged_token_ids.pop(0)
-        if token_id == _IMAGE_TOKEN_ID:
-            replacement_ids = repeat_and_pad_token(
-                _IMAGE_TOKEN_ID,
-                repeat_count=image_feature_size[placeholder_idx],
-            )
-            placeholder_ranges.append({
-                "offset": len(new_token_ids),
-                "length": len(replacement_ids)
-            })
-            new_token_ids.extend(replacement_ids)
-            placeholder_idx += 1
-        else:
-            new_token_ids.append(token_id)
-
-    # NOTE: Create a defensive copy of the original inputs
-    return token_inputs(prompt_token_ids=new_token_ids,
-                        prompt=new_prompt,
-                        multi_modal_data=multi_modal_data,
-                        multi_modal_placeholders={"image": placeholder_ranges})
+    hf_processor = ctx.get_hf_processor()
+    image_processor = hf_processor.image_processor  # type: ignore
+    hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
+
+    return MultiModalKwargs(**hf_inputs)
+
+
+def create_metadata_for_phi3v(
+        ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
+    return {
+        "image":
+        ModalityProcessingMetadata(prompt_repls=[
+            PromptReplacement(target=[_IMAGE_TOKEN_ID],
+                              repl_unit=[_IMAGE_TOKEN_ID],
+                              repl_count=get_max_phi3v_image_tokens(ctx)),
+        ]),
+    }
+
+
+class Phi3VProcessor(BaseMultiModalProcessor):
+
+    def __init__(self, ctx: InputProcessingContext) -> None:
+        super().__init__(
+            ctx=ctx,
+            metadata=create_metadata_for_phi3v(ctx),
+        )
+
+    def _get_hf_processor(
+        self,
+        *,
+        num_crops: Optional[int] = None,
+    ) -> ProcessorMixin:
+        if num_crops is not None:
+            return self.ctx.get_hf_processor(num_crops=num_crops)
+        return self.ctx.get_hf_processor()
+
+    def _apply_hf_processor(
+        self,
+        prompt: str,
+        mm_data: MultiModalDataDict,
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        processed_outputs = super()._apply_hf_processor(
+            prompt, mm_data, mm_processor_kwargs)
+        # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
+        # which will cause OverflowError when decoding the prompt_ids.
+        # Therefore, we need to do an early replacement here
+        token_ids = processed_outputs['input_ids']
+        token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
+        processed_outputs['input_ids'] = token_ids
+        return processed_outputs
+
+    def _get_dummy_mm_kwargs(
+        self,
+        mm_counts: Mapping[str, int],
+    ) -> MultiModalKwargs:
+        return dummy_mm_kwargs_for_phi3v(self.ctx, mm_counts)


-@MULTIMODAL_REGISTRY.register_image_input_mapper()
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
+@MULTIMODAL_REGISTRY.register_processor(Phi3VProcessor)
 class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

View File

@@ -3,7 +3,8 @@ from abc import ABC, abstractmethod
 from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union
+from typing import (Any, Dict, Generic, NamedTuple, Optional, Protocol,
+                    TypeVar, Union, cast)

 import torch
 from transformers import BatchFeature, ProcessorMixin
@@ -11,7 +12,8 @@ from typing_extensions import TypeAlias, TypedDict

 from vllm.inputs import DummyData, InputProcessingContext
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
+from vllm.utils import (flatten_2d_lists, full_groupby, is_list_of,
+                        resolve_mm_processor_kwargs)

 from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
                      MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
@@ -543,8 +545,14 @@ class BaseMultiModalProcessor(ABC):
         self.ctx = ctx
         self.metadata = metadata
+        self.init_mm_processor_kwargs = (ctx.model_config.mm_processor_kwargs
+                                         or {})

-    def _get_hf_processor(self) -> ProcessorMixin:
+    def _get_hf_processor(
+        self,
+        **mm_processor_kwargs: Mapping[str, object],
+    ) -> ProcessorMixin:
+        # by default, we won't pass any kwargs to the processor initialization
         return self.ctx.get_hf_processor()

     def _get_tokenizer(self) -> AnyTokenizer:
@@ -581,7 +589,13 @@
         mm_data: MultiModalDataDict,
         mm_processor_kwargs: Mapping[str, object],
     ) -> BatchFeature:
-        hf_processor = self._get_hf_processor()
+        # some mm_processor_kwargs may be used in processor initialization
+        # instead of processor call
+        processor_init_kwargs = {
+            **self.init_mm_processor_kwargs,
+            **mm_processor_kwargs,
+        }
+        hf_processor = self._get_hf_processor(**processor_init_kwargs)

         processor_data = dict[str, Any]()
         passthrough_data = dict[str, Any]()
@@ -601,6 +615,13 @@
             else:
                 processor_data[k] = v

+        # filter mm_processor_kwargs used in processor call
+        mm_processor_kwargs = resolve_mm_processor_kwargs(
+            self.init_mm_processor_kwargs,
+            cast(Dict[str, Any], mm_processor_kwargs),
+            hf_processor,
+        )
+
         try:
             hf_inputs = hf_processor(
                 text=prompt,  # type: ignore