[Model] merged input processor for Phi-3-Vision models (#10977)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

commit a811dd6608 (parent ca871491ed)
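
A minimal sketch of the call path this commit introduces, pieced together from the updated tests below; `build_model_context` is assumed from vLLM's test utilities and `image` is an assumed PIL image, so treat this as illustrative rather than a runnable recipe:

    from transformers import AutoTokenizer

    from vllm.inputs import InputProcessingContext
    from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID, Phi3VProcessor

    model_id = "microsoft/Phi-3.5-vision-instruct"
    ctx = build_model_context(model_name=model_id, tokenizer_name=model_id,
                              trust_remote_code=True)  # test helper, assumed
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    ctx = InputProcessingContext(ctx.model_config, tokenizer)

    # One processor object now handles tokenization, the HF image processor
    # call, and image-placeholder expansion in a single apply() step:
    processor = Phi3VProcessor(ctx)
    processed = processor.apply("<|user|>\n<|image_1|>\n<|end|>\n<|assistant|>\n",
                                {"image": [image]},  # `image`: assumed PIL image
                                {"num_crops": 4})
    image_token_count = processed["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
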
@@ -89,7 +89,7 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=772, total_tokens=782)
+        completion_tokens=10, prompt_tokens=775, total_tokens=785)
 
     message = choice.message
     message = chat_completion.choices[0].message
@@ -181,7 +181,7 @@ async def test_single_chat_session_image_base64encoded(
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=772, total_tokens=782)
+        completion_tokens=10, prompt_tokens=775, total_tokens=785)
 
     message = choice.message
     message = chat_completion.choices[0].message
@@ -95,5 +95,5 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
     assert len(embeddings["data"]) == 1
     assert len(embeddings["data"][0]["embedding"]) == 3072
     assert embeddings["usage"]["completion_tokens"] == 0
-    assert embeddings["usage"]["prompt_tokens"] == 762
-    assert embeddings["usage"]["total_tokens"] == 762
+    assert embeddings["usage"]["prompt_tokens"] == 765
+    assert embeddings["usage"]["total_tokens"] == 765
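
All three expected-usage updates above move by the same amount, since the chat and embedding suites share the per-image feature size computed by the merged processor; a quick sanity check of the constants:

    old_counts = {"chat": 772, "embedding": 762}
    new_counts = {"chat": 775, "embedding": 765}
    # both suites shift by the same 3 tokens for the single test image
    assert all(new_counts[k] - old_counts[k] == 3 for k in old_counts)
    assert 775 + 10 == 785  # usage totals remain prompt + completion tokens
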
@@ -2,12 +2,10 @@
 from typing import Optional
 
 import pytest
-import torch
-from transformers import AutoImageProcessor, AutoTokenizer
+from transformers import AutoTokenizer
 
-from vllm.inputs import InputContext, token_inputs
+from vllm.inputs import InputContext, InputProcessingContext
 from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
-from vllm.multimodal import MultiModalRegistry
 
 from .....conftest import _ImageAssets
 from ....utils import build_model_context
@@ -17,15 +15,9 @@ models = ["microsoft/Phi-3.5-vision-instruct"]
 
 # Wrap lazy imports to avoid initializing CUDA during test collection
 @pytest.fixture()
-def input_processor_for_phi3v():
-    from vllm.model_executor.models.phi3v import input_processor_for_phi3v
-    return input_processor_for_phi3v
-
-
-@pytest.fixture()
-def dummy_data_for_phi3v():
-    from vllm.model_executor.models.phi3v import dummy_data_for_phi3v
-    return dummy_data_for_phi3v
+def processor_for_phi3v():
+    from vllm.model_executor.models.phi3v import Phi3VProcessor
+    return Phi3VProcessor
 
 
 @pytest.fixture()
@@ -34,53 +26,6 @@ def get_max_phi3v_image_tokens():
     return get_max_phi3v_image_tokens
 
 
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("num_crops", [4, 16, None])
-def test_input_mapper_override(model: str, image_assets: _ImageAssets,
-                               num_crops: Optional[int]):
-    """Ensure that the [default] input mapper handles num_crops properly."""
-    # We pass the processor kwargs here since for this model, we fall back to
-    # the default mapper; this will fall back to the HF mapper and forward
-    # mm_processor_kwargs to it.
-    mm_processor_kwargs = {
-        "num_crops": num_crops
-    } if num_crops is not None else {}
-    ctx = build_model_context(
-        model_name=model,
-        tokenizer_name=model,
-        trust_remote_code=True,
-        mm_processor_kwargs=mm_processor_kwargs,
-    )
-
-    hf_processor = AutoImageProcessor.from_pretrained(model,
-                                                      trust_remote_code=True,
-                                                      **mm_processor_kwargs)
-
-    mm_registry = MultiModalRegistry()
-    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-
-    image = image_assets[0].pil_image
-    hf_result = hf_processor.preprocess(
-        image,
-        return_tensors="pt",
-    )
-
-    vllm_result = mm_registry.map_input(
-        ctx.model_config,
-        {"image": image},
-    )
-
-    assert torch.all(hf_result["image_sizes"] == vllm_result["image_sizes"])
-    assert torch.all(
-        hf_result["num_img_tokens"] == vllm_result["num_img_tokens"])
-
-    # For pixel values, the second axis should be the num_crops + 1
-    # for the rescaled original image. The default value in VLLM falls
-    # back to the HF config, which is why we compare to the processor num_crops
-    assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"])
-    assert vllm_result["pixel_values"].shape[1] == hf_processor.num_crops + 1
-
-
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize("num_crops,expected_max_tokens", [
     (4, 781),
@@ -112,48 +57,20 @@ def test_max_tokens_override(get_max_phi3v_image_tokens, model: str,
 
 
 @pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("num_crops,toks_per_img,num_imgs", [
-    (4, 781, 1),
-    (4, 781, 2),
-    (16, 2653, 1),
-    (16, 2653, 2),
-])
-def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int,
-                             toks_per_img: int, num_imgs: int):
-    """Ensure dummy_data_for_phi3v handles num_crops properly."""
-    # Same as the previous test - don't initialize mm_processor_kwargs
-    # in this test and assume that the kwargs will be correctly expanded by
-    # the partial when calling the dummy data func.
-    ctx = build_model_context(
-        model_name=model,
-        tokenizer_name=model,
-        trust_remote_code=True,
-        mm_processor_kwargs=None,
-    )
-
-    dummy_data = dummy_data_for_phi3v(
-        ctx=ctx,
-        seq_len=8192,  # Should be bigger than num_imgs * toks_per_img
-        mm_counts={"image": num_imgs},
-        num_crops=num_crops,
-    )
-    sequence_data = dummy_data.seq_data
-    # Ensure we have the right number of placeholders per num_crops size
-    img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID)
-    assert img_tok_count == toks_per_img * num_imgs
-
-
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("num_crops,expected_toks_per_img,num_imgs", [
-    (4, 757, 1),
-    (4, 757, 2),
-    (16, 1921, 1),
-    (16, 1921, 2),
-])
-def test_input_processor_override(input_processor_for_phi3v,
-                                  image_assets: _ImageAssets, model: str,
-                                  num_crops: int, expected_toks_per_img: int,
-                                  num_imgs: int):
+@pytest.mark.parametrize(
+    "num_crops,expected_toks_per_img,num_imgs",
+    [
+        (4, 757, 1),
+        (4, 757, 2),
+        (16, 1921, 1),
+        (16, 1921, 2),
+        # the default num_crops of phi-3.5-vision is 4
+        (None, 757, 2),
+        (None, 757, 2),
+    ])
+def test_processor_override(processor_for_phi3v, image_assets: _ImageAssets,
+                            model: str, num_crops: Optional[int],
+                            expected_toks_per_img: int, num_imgs: int):
     """Ensure input_processor_for_phi3v handles num_crops properly."""
     # Same as the previous test - don't initialize mm_processor_kwargs
     # in this test and assume that the kwargs will be correctly expanded by
@@ -163,19 +80,20 @@ def test_input_processor_override(input_processor_for_phi3v,
         tokenizer_name=model,
         trust_remote_code=True,
     )
-    tokenizer = AutoTokenizer.from_pretrained(model)
+    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+    ctx = InputProcessingContext(ctx.model_config, tokenizer)
     # Build the image str / prompt based on the number of images we pass
     img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
     prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
     images = [image_assets[0].pil_image] * num_imgs
 
-    inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
-                          prompt=prompt,
-                          multi_modal_data={"image": images})
+    mm_data = {"image": images}
+    mm_processor_kwargs = {}
+    if num_crops is not None:
+        mm_processor_kwargs = {"num_crops": num_crops}
 
-    processed_inputs = input_processor_for_phi3v(ctx,
-                                                 inputs,
-                                                 num_crops=num_crops)
+    processor = processor_for_phi3v(ctx)
+    processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
 
     # Ensure we have the right number of placeholders per num_crops size
     img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
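
Where the expected_toks_per_img values come from: the merged processor defers per-image token math to the HF image processor, so a num_crops override changes the count it reports. A hedged sketch, assuming the MAX_IMAGE_FEATURE_SIZE_* constants are importable from phi3v.py (they are module-level there per the hunks below):

    from vllm.model_executor.models.phi3v import (MAX_IMAGE_FEATURE_SIZE_HEIGHT,
                                                  MAX_IMAGE_FEATURE_SIZE_WIDTH)
    from vllm.multimodal.image import cached_get_image_processor

    image_processor = cached_get_image_processor(
        "microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        num_crops=16,  # override; this checkpoint's default is 4
    )
    # for the test asset this lands at the 1921/757 values parametrized above
    num_tokens = image_processor.calc_num_image_tokens_from_image_size(
        width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
    )
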
@@ -15,13 +15,13 @@ from ..models.utils import build_model_context
 # Used for fast tests where the model doesn't matter
 DUMMY_MODEL_ID = "facebook/opt-125m"
 # Used for tests that need a multimodal model
-MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
+MULTIMODAL_MODEL_ID = "OpenGVLab/InternVL2-2B"
 
 # For mm_processor_kwargs - we test overrides by defining mocks for each place
 # it is used, and ensuring that we can pass processor kwargs an override value
 # to receive the intended result for things like sequence length etc.
-DEFAULT_NUM_CROPS = 4
-NUM_CROPS_OVERRIDE = 16
+DEFAULT_MAX_DYNAMIC_PATCH = 6
+MAX_DYNAMIC_PATCH_OVERRIDE = 4
 
 
 # Mocks for all of the places that we use the mm_processor_kwargs
@@ -33,10 +33,11 @@ def use_processor_mock():
     def custom_processor(ctx: InputContext,
                          inputs: DecoderOnlyInputs,
                          *,
-                         num_crops=DEFAULT_NUM_CROPS):
+                         max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH):
         # For testing purposes, we don't worry about the prompt
-        return token_inputs(prompt_token_ids=[],
-                            mm_processor_kwargs={"num_crops": num_crops})
+        return token_inputs(
+            prompt_token_ids=[],
+            mm_processor_kwargs={"max_dynamic_patch": max_dynamic_patch})
 
     with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor",
                return_value=custom_processor):
@@ -52,9 +53,9 @@ def use_dummy_data_mock():
                           seq_len: int,
                           mm_counts: Mapping[str, int],
                           *,
-                          num_crops=DEFAULT_NUM_CROPS):
+                          max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH):
         seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops))
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * max_dynamic_patch))
         return DummyData(seq_data, None)
 
     with patch(
@@ -65,15 +66,15 @@ def use_dummy_data_mock():
 
 # Lazy import to avoid CUDA reinitialization error
 def mm_model_cls():
-    from vllm.model_executor.models.phi3v import Phi3VForCausalLM
+    from vllm.model_executor.models.internvl import InternVLChatModel
 
-    return Phi3VForCausalLM
+    return InternVLChatModel
 
 
 # lambda whose signature matches max token calcs extra & mapper + extra kwargs
-get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops
-custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: {
-    "pixel_values": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
+get_max_dynamic_patch = lambda ctx, *, max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH: max_dynamic_patch  # noqa: E501
+custom_mapper = lambda ctx, data, *, max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH: {  # noqa: E501
+    "pixel_values": torch.zeros(size=(1, max_dynamic_patch + 1, 3, 448, 448))
 }
 
 
@@ -88,27 +89,28 @@ def test_default_processor_is_a_noop():
     assert proc_inputs is proc_outputs
 
 
-def _get_num_crops_info(init_num_crops: int, inference_num_crops: int):
-    """Get the init / inference kwargs and expected num_crops for this test."""
-    # If we have a value for num_crops, pass the override value and make
+def _get_max_dynamic_patch_info(init_max_dynamic_patch: int,
+                                inference_max_dynamic_patch: int):
+    """Get the init / inference kwargs and expected max_dynamic_patch."""
+    # If we have a value for max_dynamic_patch, pass the override value and make
     # sure we get that value as a return-value from out mock processor,
     # otherwise fall back to the default value
-    init_kwargs = None if init_num_crops is None else {
-        "num_crops": init_num_crops
+    init_kwargs = None if init_max_dynamic_patch is None else {
+        "max_dynamic_patch": init_max_dynamic_patch
     }
-    inference_kwargs = None if inference_num_crops is None else {
-        "num_crops": inference_num_crops
+    inference_kwargs = None if inference_max_dynamic_patch is None else {
+        "max_dynamic_patch": inference_max_dynamic_patch
     }
-    if inference_num_crops is not None:
-        expected_seq_count = inference_num_crops
-    elif init_num_crops is not None:
-        expected_seq_count = init_num_crops
+    if inference_max_dynamic_patch is not None:
+        expected_seq_count = inference_max_dynamic_patch
+    elif init_max_dynamic_patch is not None:
+        expected_seq_count = init_max_dynamic_patch
     else:
-        expected_seq_count = DEFAULT_NUM_CROPS
+        expected_seq_count = DEFAULT_MAX_DYNAMIC_PATCH
     return init_kwargs, inference_kwargs, expected_seq_count
 
 
-def _get_processed_num_crops(
+def _get_processed_max_dynamic_patch(
     processor: Callable[[ProcessorInputs], ProcessorInputs],
     inference_kwargs: Optional[Dict[str, int]],
 ) -> int:
@@ -120,27 +122,30 @@ def _get_processed_num_crops(
     assert "type" in processed_inputs
     assert processed_inputs["type"] == "token"
    assert "mm_processor_kwargs" in processed_inputs
-    return processed_inputs["mm_processor_kwargs"]["num_crops"]
+    return processed_inputs["mm_processor_kwargs"]["max_dynamic_patch"]
 
 
-@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
-    (None, None),
-    (NUM_CROPS_OVERRIDE, None),
-    (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
-])
-def test_input_processor_kwargs(use_processor_mock, init_num_crops,
-                                inference_num_crops):
+@pytest.mark.parametrize(
+    "init_max_dynamic_patch,inference_max_dynamic_patch", [
+        (None, None),
+        (MAX_DYNAMIC_PATCH_OVERRIDE, None),
+        (DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE),
+    ])
+def test_input_processor_kwargs(use_processor_mock, init_max_dynamic_patch,
+                                inference_max_dynamic_patch):
     """Ensure input processors can use processor kwargs."""
     dummy_registry = InputRegistry()
 
-    init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
-        init_num_crops, inference_num_crops)
+    (init_kwargs, inference_kwargs,
+     expected_seq_count) = _get_max_dynamic_patch_info(
+         init_max_dynamic_patch, inference_max_dynamic_patch)
 
     ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
     processor = dummy_registry.create_input_processor(ctx.model_config)
-    num_crops_val = _get_processed_num_crops(processor, inference_kwargs)
+    max_dynamic_patch_val = _get_processed_max_dynamic_patch(
+        processor, inference_kwargs)
 
-    assert num_crops_val == expected_seq_count
+    assert max_dynamic_patch_val == expected_seq_count
 
 
 @pytest.mark.parametrize(
@@ -165,18 +170,21 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
 
     processor = dummy_registry.create_input_processor(ctx.model_config)
     # Should filter out the inference time kwargs
-    num_crops_val = _get_processed_num_crops(processor, mm_processor_kwargs)
-    assert num_crops_val == DEFAULT_NUM_CROPS
+    max_dynamic_patch_val = _get_processed_max_dynamic_patch(
+        processor, mm_processor_kwargs)
+    assert max_dynamic_patch_val == DEFAULT_MAX_DYNAMIC_PATCH
 
 
 ### Test overrides for the dummy data
-@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
-def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
+@pytest.mark.parametrize("max_dynamic_patch",
+                         [None, MAX_DYNAMIC_PATCH_OVERRIDE])
+def test_dummy_data_kwarg_overrides(use_dummy_data_mock, max_dynamic_patch):
     """Ensure dummy data factories can use processor kwargs."""
-    mm_processor_kwargs = None if num_crops is None else {
-        "num_crops": num_crops
+    mm_processor_kwargs = None if max_dynamic_patch is None else {
+        "max_dynamic_patch": max_dynamic_patch
     }
-    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
+    expected_seq_count = (DEFAULT_MAX_DYNAMIC_PATCH
+                          if max_dynamic_patch is None else max_dynamic_patch)
     dummy_registry = InputRegistry()
     ctx = build_model_context(DUMMY_MODEL_ID,
                               mm_processor_kwargs=mm_processor_kwargs)
@@ -217,17 +225,20 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
     # len is solely dependent on the value of the mm_processor_kwargs.
     dummy_data = dummy_registry.dummy_data_for_profiling(
         ctx.model_config, seq_len=-1, mm_registry=mm_registry)
-    assert len(dummy_data.seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
+    assert len(
+        dummy_data.seq_data.prompt_token_ids) == DEFAULT_MAX_DYNAMIC_PATCH
 
 
 ### Test overrides for the max token count per multimodal instance
-@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
-def test_max_tokens_kwarg_overrides(num_crops):
+@pytest.mark.parametrize("max_dynamic_patch",
+                         [None, MAX_DYNAMIC_PATCH_OVERRIDE])
+def test_max_tokens_kwarg_overrides(max_dynamic_patch):
     """Ensure max token calcs can use processor kwargs."""
-    mm_processor_kwargs = None if num_crops is None else {
-        "num_crops": num_crops
+    mm_processor_kwargs = None if max_dynamic_patch is None else {
+        "max_dynamic_patch": max_dynamic_patch
     }
-    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
+    expected_seq_count = (DEFAULT_MAX_DYNAMIC_PATCH
+                          if max_dynamic_patch is None else max_dynamic_patch)
 
     ctx = build_model_context(MULTIMODAL_MODEL_ID,
                               task="generate",
@@ -239,11 +250,11 @@ def test_max_tokens_kwarg_overrides(num_crops):
     mm_registry.init_mm_limits_per_prompt(ctx.model_config)
     # Patch the image registry for phi3v with our lambda that is compatible
     # with overrides, then ensure that calling the method correctly echos
-    # our num_crops value back from the mm_processor_kwargs.
+    # our max_dynamic_patch value back from the mm_processor_kwargs.
     with patch.object(
             mm_registry._get_plugin("image"),
             "_max_mm_tokens",
-            {mm_model_cls(): get_num_crops},
+            {mm_model_cls(): get_max_dynamic_patch},
     ):
         max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
             ctx.model_config)
@@ -279,26 +290,29 @@ def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
     with patch.object(
             mm_registry._get_plugin("image"),
             "_max_mm_tokens",
-            {mm_model_cls(): get_num_crops},
+            {mm_model_cls(): get_max_dynamic_patch},
     ):
         max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
             ctx.model_config)
 
-        assert max_multimodal_tokens == DEFAULT_NUM_CROPS
+        assert max_multimodal_tokens == DEFAULT_MAX_DYNAMIC_PATCH
 
 
 ### Test overrides for the mapper
-@pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE])
-def test_default_mapper_with_processor_kwargs(image_assets, num_crops):
+@pytest.mark.parametrize(
+    "max_dynamic_patch",
+    [DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE])
+def test_default_mapper_with_processor_kwargs(image_assets, max_dynamic_patch):
     """Ensure that the mapper processor kwargs can fall back to HF models."""
     # NOTE - we don't validate bad inputs for the default mapper, because it's
     # through the automodel interface in transformers, so we can't easily
     # inspect what kwargs are or are not allowed.
-    ctx = build_model_context(MULTIMODAL_MODEL_ID,
-                              task="generate",
-                              trust_remote_code=True,
-                              mm_processor_kwargs={"num_crops": num_crops},
-                              limit_mm_per_prompt={"image": 1})
+    ctx = build_model_context(
+        MULTIMODAL_MODEL_ID,
+        task="generate",
+        trust_remote_code=True,
+        mm_processor_kwargs={"max_dynamic_patch": max_dynamic_patch},
+        limit_mm_per_prompt={"image": 1})
 
     mm_registry = MultiModalRegistry()
     mm_registry.init_mm_limits_per_prompt(ctx.model_config)
@@ -307,20 +321,22 @@ def test_default_mapper_with_processor_kwargs(image_assets, num_crops):
     mm_inputs = {"image": image}
 
     mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
-    # Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336]
-    assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1
+    # pixel vals should have shape: [batch, max_dynamic_patch+1, ...]
+    assert mapped_inputs["pixel_values"].shape[1] == max_dynamic_patch + 1
 
 
-@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
-    (None, None),
-    (NUM_CROPS_OVERRIDE, None),
-    (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
-])
-def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops,
-                                       inference_num_crops):
+@pytest.mark.parametrize(
+    "init_max_dynamic_patch,inference_max_dynamic_patch", [
+        (None, None),
+        (MAX_DYNAMIC_PATCH_OVERRIDE, None),
+        (DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE),
+    ])
+def test_custom_mapper_kwarg_overrides(image_assets, init_max_dynamic_patch,
                                        inference_max_dynamic_patch):
     """Ensure custom mappers can use processor kwargs."""
-    init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
-        init_num_crops, inference_num_crops)
+    (init_kwargs, inference_kwargs,
+     expected_seq_count) = _get_max_dynamic_patch_info(
+         init_max_dynamic_patch, inference_max_dynamic_patch)
 
     ctx = build_model_context(MULTIMODAL_MODEL_ID,
                               task="generate",
@@ -335,7 +351,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops,
 
     # Patch the image registry for phi3v with our lambda that is compatible
     # with overrides, then ensure that calling the method correctly echos
-    # our num_crops value back from the mm_processor_kwargs.
+    # our max_dynamic_patch value back from the mm_processor_kwargs.
     mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
         mm_model_cls())
     mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs,
@@ -373,11 +389,12 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
 
     # Patch the image registry for phi3v with our lambda that is compatible
     # with overrides, then ensure that calling the method correctly echos
-    # our num_crops value back from the mm_processor_kwargs.
+    # our max_dynamic_patch value back from the mm_processor_kwargs.
     mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
         mm_model_cls())
     # Should filter out the inference time kwargs
     mapped_inputs = mm_registry.map_input(
         ctx.model_config, mm_inputs, mm_processor_kwargs=mm_processor_kwargs)
 
-    assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1
+    assert mapped_inputs["pixel_values"].shape[1] == (
+        DEFAULT_MAX_DYNAMIC_PATCH + 1)
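
The override rule these tests pin down: an inference-time mm_processor_kwarg beats an init-time one, which beats the mock's default. A compact restatement (the helper below is hypothetical, not from the commit):

    DEFAULT = 6  # DEFAULT_MAX_DYNAMIC_PATCH in the file above

    def expected(init_val, inference_val, default=DEFAULT):
        # inference-time kwargs win over init-time kwargs, which win over defaults
        if inference_val is not None:
            return inference_val
        if init_val is not None:
            return init_val
        return default

    assert expected(None, None) == 6
    assert expected(4, None) == 4  # MAX_DYNAMIC_PATCH_OVERRIDE at init time
    assert expected(6, 4) == 4     # inference-time override wins
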
@@ -69,12 +69,12 @@ class InputProcessingContext(InputContext):
     tokenizer: AnyTokenizer
     """The tokenizer used to tokenize the inputs."""
 
-    def get_hf_processor(self) -> ProcessorMixin:
+    def get_hf_processor(self, **kwargs) -> ProcessorMixin:
         return cached_get_processor(
             self.model_config.tokenizer,
             tokenizer=self.tokenizer,  # Override the tokenizer with ours
             trust_remote_code=self.model_config.trust_remote_code,
-        )
+            **kwargs)
 
 
 N = TypeVar("N", bound=Type[nn.Module])
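
With **kwargs accepted here, a model-specific processor can forward construction-time options to the cached HF processor. This is how Phi3VProcessor._get_hf_processor in this commit uses it, shown as a sketch outside its class:

    from typing import Optional

    from transformers import ProcessorMixin

    def _get_hf_processor(self, *,
                          num_crops: Optional[int] = None) -> ProcessorMixin:
        if num_crops is not None:
            # forwarded into the HF processor constructor via cached_get_processor
            return self.ctx.get_hf_processor(num_crops=num_crops)
        return self.ctx.get_hf_processor()
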
@@ -12,22 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import itertools
-import re
-from functools import cached_property, lru_cache
-from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
-                    Tuple, TypedDict, Union)
+from functools import cached_property
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
+                    TypedDict, Union)
 
-import numpy as np
 import torch
 import torch.nn as nn
-from PIL import Image
-from transformers import CLIPVisionConfig, PretrainedConfig
+from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig,
+                          ProcessorMixin)
 
 from vllm.attention import AttentionMetadata
-from vllm.config import ModelConfig, VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
+from vllm.config import VllmConfig
+from vllm.inputs import InputContext
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -36,12 +32,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
-from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
+from vllm.multimodal.image import cached_get_image_processor
+from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        InputProcessingContext,
+                                        ModalityProcessingMetadata,
+                                        MultiModalDataDict,
+                                        MultiModalProcessingMetadata,
+                                        PromptReplacement)
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
 
-from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
+from .clip import dummy_image_for_clip
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
@@ -303,231 +305,99 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
         return image_features_hd_newline
 
 
-# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
-def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
-    target_height = int(np.ceil(height / padding_unit) * padding_unit)
-    top_padding = int((target_height - height) / 2)
-    bottom_padding = target_height - height - top_padding
-    padded_width = width
-    padded_height = height + top_padding + bottom_padding
-    return padded_width, padded_height
-
-
-# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
-def _calc_hd_transform_size(*, width: int, height: int, hd_num: int):
-    transposed = False
-    if width < height:
-        width, height = height, width
-        transposed = True
-
-    ratio = width / height
-    scale = 1
-    while scale * np.ceil(scale / ratio) <= hd_num:
-        scale += 1
-    scale -= 1
-
-    new_width = int(scale * 336)
-    new_height = int(new_width / ratio)
-
-    padded_width, padded_height = _calc_padded_size(width=new_width,
-                                                    height=new_height)
-
-    if transposed:
-        padded_width, padded_height = padded_height, padded_width
-
-    return padded_width, padded_height
-
-
-# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
-def get_phi3v_image_feature_size(
-    hf_config: Dict[str, Any],
-    *,
-    input_height: int,
-    input_width: int,
-    num_crops: int,
-) -> int:
-    if num_crops is None:
-        num_crops = hf_config.get("num_crops", 16)
-    new_width, new_height = _calc_hd_transform_size(width=input_width,
-                                                    height=input_height,
-                                                    hd_num=num_crops)
-
-    return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
-        + (new_height // 336 + 1) * 12
-
-
 def get_max_phi3v_image_tokens(ctx: InputContext,
                                *,
                                num_crops: Optional[int] = None):
+    mm_processor_kwargs = {}
+    if num_crops is not None:
+        mm_processor_kwargs["num_crops"] = num_crops
 
-    return get_phi3v_image_feature_size(
-        ctx.get_hf_image_processor_config(),
-        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
-        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
-        num_crops=num_crops,
+    model_config = ctx.model_config
+    image_processor = cached_get_image_processor(
+        model_config.model,
+        trust_remote_code=model_config.trust_remote_code,
+        **mm_processor_kwargs,
     )
 
+    num_tokens = image_processor.calc_num_image_tokens_from_image_size(
+        width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
+        height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+    )
+    return num_tokens
 
-def dummy_data_for_phi3v(ctx: InputContext,
-                         seq_len: int,
-                         mm_counts: Mapping[str, int],
-                         *,
-                         num_crops: Optional[int] = None):
+
+def dummy_mm_kwargs_for_phi3v(ctx: InputProcessingContext,
+                              mm_counts: Mapping[str, int]):
     num_images = mm_counts["image"]
 
-    image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops)
-
-    seq_data, ranges = dummy_seq_data_for_clip(
-        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
-        seq_len,
-        num_images,
-        image_token_id=_IMAGE_TOKEN_ID,
-        image_feature_size_override=image_feature_size,
-    )
-    mm_data = dummy_image_for_clip(
+    data = dummy_image_for_clip(
         CLIP_VIT_LARGE_PATCH14_336_CONFIG,
         num_images,
         image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
         image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
     )
 
-    return DummyData(seq_data, mm_data, ranges)
+    hf_processor = ctx.get_hf_processor()
+    image_processor = hf_processor.image_processor  # type: ignore
+    hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
+
+    return MultiModalKwargs(**hf_inputs)
 
 
-@lru_cache
-def _get_image_placeholder_token_id_candidates(
-    model_config: ModelConfig,
-    idx: int,
-) -> List[List[int]]:
-    assert idx > 0
-
-    tokenizer = cached_get_tokenizer(model_config.tokenizer)
-
-    # This is used when the image token is at the start of the string
-    start_candidate = tokenizer.encode(f"<|image_{idx}|>",
-                                       add_special_tokens=False)
-
-    # This is used when the image token is in the middle of the string
-    # We need to get the token for "<", not "▁<"
-    # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
-    a_token_id, = tokenizer.encode("a", add_special_tokens=False)
-    a_token_id_, *middle_candidate = tokenizer.encode(f"a<|image_{idx}|>",
-                                                      add_special_tokens=False)
-    assert a_token_id == a_token_id_
-
-    return [start_candidate, middle_candidate]
+def create_metadata_for_phi3v(
+        ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
+    return {
+        "image":
+        ModalityProcessingMetadata(prompt_repls=[
+            PromptReplacement(target=[_IMAGE_TOKEN_ID],
+                              repl_unit=[_IMAGE_TOKEN_ID],
+                              repl_count=get_max_phi3v_image_tokens(ctx)),
+        ]),
+    }
 
 
-def input_processor_for_phi3v(ctx: InputContext,
-                              inputs: DecoderOnlyInputs,
-                              *,
-                              num_crops: Optional[int] = None):
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        return inputs
-
-    model_config = ctx.model_config
-    hf_config = ctx.get_hf_image_processor_config()
-
-    image_data = multi_modal_data["image"]
-    if isinstance(image_data, Image.Image):
-        w, h = image_data.size
-        image_feature_size = [
-            get_phi3v_image_feature_size(hf_config,
-                                         input_width=w,
-                                         input_height=h,
-                                         num_crops=num_crops)
-        ]
-        image_data = [image_data]
-    elif is_list_of(image_data, Image.Image):
-        image_feature_size = []
-        for image in image_data:
-            w, h = image.size
-            image_feature_size.append(
-                get_phi3v_image_feature_size(hf_config,
-                                             input_width=w,
-                                             input_height=h,
-                                             num_crops=num_crops))
-    elif isinstance(image_data, torch.Tensor):
-        image_feature_size = [image_data.shape[0]]
-        image_data = [image_data]
-    elif is_list_of(image_data, torch.Tensor):
-        image_feature_size = [item.shape[0] for item in image_data]
-    else:
-        raise TypeError(f"Invalid image type: {type(image_data)}")
-
-    prompt = inputs.get("prompt")
-    if prompt is None:
-        # for async server request, we assume prompt and its token_ids is always
-        # in correct format. And num_image_tags == len(image_data) always True.
-        image_idx = range(1, len(image_data) + 1)
-        new_prompt = None
-    else:
-        image_idx = sorted(map(int, re.findall(r"<\|image_(\d+)\|>+", prompt)))
-        if prompt.count("<|image|>") > 0:
-            logger.warning("Please follow the prompt format that is "
-                           "documented on HuggingFace which does not involve "
-                           "repeating <|image|> tokens.")
-        elif (num_image_tags := len(image_idx)) > 1:
-            assert num_image_tags == len(
-                image_data), "The count of image_placeholder not match image's"
-        new_prompt = prompt
-
-    prompt_token_ids = inputs["prompt_token_ids"].copy()
-
-    # masked placeholder with image token id
-    for idx in image_idx:
-        candidates = _get_image_placeholder_token_id_candidates(model_config,
-                                                                idx=idx)
-
-        for candidate in candidates:
-            for i in range(len(prompt_token_ids) - len(candidate) + 1):
-                if prompt_token_ids[i:i + len(candidate)] == candidate:
-                    prompt_token_ids[i:i +
-                                     len(candidate)] = ([_IMAGE_TOKEN_ID] *
-                                                        len(candidate))
-                    break
-
-    # merge consecutive tag ids
-    merged_token_ids: List[int] = []
-    for is_placeholder, token_ids in itertools.groupby(
-            prompt_token_ids, lambda x: x == _IMAGE_TOKEN_ID):
-        if is_placeholder:
-            merged_token_ids.append(_IMAGE_TOKEN_ID)
-        else:
-            merged_token_ids.extend(list(token_ids))
-
-    # TODO: Move this to utils or integrate with clip.
-    new_token_ids: List[int] = []
-    placeholder_ranges: List[PlaceholderRange] = []
-    placeholder_idx = 0
-    while merged_token_ids:
-        token_id = merged_token_ids.pop(0)
-        if token_id == _IMAGE_TOKEN_ID:
-            replacement_ids = repeat_and_pad_token(
-                _IMAGE_TOKEN_ID,
-                repeat_count=image_feature_size[placeholder_idx],
-            )
-            placeholder_ranges.append({
-                "offset": len(new_token_ids),
-                "length": len(replacement_ids)
-            })
-            new_token_ids.extend(replacement_ids)
-            placeholder_idx += 1
-        else:
-            new_token_ids.append(token_id)
-
-    # NOTE: Create a defensive copy of the original inputs
-    return token_inputs(prompt_token_ids=new_token_ids,
-                        prompt=new_prompt,
-                        multi_modal_data=multi_modal_data,
-                        multi_modal_placeholders={"image": placeholder_ranges})
+class Phi3VProcessor(BaseMultiModalProcessor):
+
+    def __init__(self, ctx: InputProcessingContext) -> None:
+        super().__init__(
+            ctx=ctx,
+            metadata=create_metadata_for_phi3v(ctx),
+        )
+
+    def _get_hf_processor(
+        self,
+        *,
+        num_crops: Optional[int] = None,
+    ) -> ProcessorMixin:
+        if num_crops is not None:
+            return self.ctx.get_hf_processor(num_crops=num_crops)
+        return self.ctx.get_hf_processor()
+
+    def _apply_hf_processor(
+        self,
+        prompt: str,
+        mm_data: MultiModalDataDict,
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        processed_outputs = super()._apply_hf_processor(
+            prompt, mm_data, mm_processor_kwargs)
+        # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
+        # which will cause OverflowError when decoding the prompt_ids.
+        # Therefore, we need to do an early replacement here
+        token_ids = processed_outputs['input_ids']
+        token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
+        processed_outputs['input_ids'] = token_ids
+        return processed_outputs
+
+    def _get_dummy_mm_kwargs(
+        self,
+        mm_counts: Mapping[str, int],
+    ) -> MultiModalKwargs:
+        return dummy_mm_kwargs_for_phi3v(self.ctx, mm_counts)
 
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper()
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
+@MULTIMODAL_REGISTRY.register_processor(Phi3VProcessor)
 class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
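
The PromptReplacement metadata above replaces roughly two hundred lines of hand-rolled token surgery: each image placeholder in the tokenized prompt is expanded to repl_count copies of repl_unit. A toy restatement of that contract (not the vLLM implementation; the token ids are made up):

    def expand_placeholders(token_ids, target, repl_unit, repl_count):
        out = []
        for tok in token_ids:
            if tok == target:
                out.extend(repl_unit * repl_count)  # one placeholder -> N slots
            else:
                out.append(tok)
        return out

    # e.g. with a hypothetical image token id 99 and repl_count of 3:
    assert expand_placeholders([1, 99, 2], 99, [99], 3) == [1, 99, 99, 99, 2]
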
@@ -3,7 +3,8 @@ from abc import ABC, abstractmethod
 from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union
+from typing import (Any, Dict, Generic, NamedTuple, Optional, Protocol,
+                    TypeVar, Union, cast)
 
 import torch
 from transformers import BatchFeature, ProcessorMixin
@@ -11,7 +12,8 @@ from typing_extensions import TypeAlias, TypedDict
 
 from vllm.inputs import DummyData, InputProcessingContext
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
+from vllm.utils import (flatten_2d_lists, full_groupby, is_list_of,
+                        resolve_mm_processor_kwargs)
 
 from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
                      MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
@@ -543,8 +545,14 @@ class BaseMultiModalProcessor(ABC):
 
         self.ctx = ctx
         self.metadata = metadata
+        self.init_mm_processor_kwargs = (ctx.model_config.mm_processor_kwargs
+                                         or {})
 
-    def _get_hf_processor(self) -> ProcessorMixin:
+    def _get_hf_processor(
+        self,
+        **mm_processor_kwargs: Mapping[str, object],
+    ) -> ProcessorMixin:
+        # by default, we won't pass any kwargs to the processor initialization
        return self.ctx.get_hf_processor()
 
     def _get_tokenizer(self) -> AnyTokenizer:
@@ -581,7 +589,13 @@ class BaseMultiModalProcessor(ABC):
         mm_data: MultiModalDataDict,
         mm_processor_kwargs: Mapping[str, object],
     ) -> BatchFeature:
-        hf_processor = self._get_hf_processor()
+        # some mm_processor_kwargs may be used in processor initialization
+        # instead of processor call
+        processor_init_kwargs = {
+            **self.init_mm_processor_kwargs,
+            **mm_processor_kwargs,
+        }
+        hf_processor = self._get_hf_processor(**processor_init_kwargs)
 
         processor_data = dict[str, Any]()
         passthrough_data = dict[str, Any]()
@@ -601,6 +615,13 @@ class BaseMultiModalProcessor(ABC):
         else:
             processor_data[k] = v
 
+        # filter mm_processor_kwargs used in processor call
+        mm_processor_kwargs = resolve_mm_processor_kwargs(
+            self.init_mm_processor_kwargs,
+            cast(Dict[str, Any], mm_processor_kwargs),
+            hf_processor,
+        )
+
         try:
             hf_inputs = hf_processor(
                 text=prompt,  # type: ignore
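
Taken together, the two changes above give mm_processor_kwargs a two-stage life: init-time and call-time values are merged for processor construction, then filtered against the processor's call signature before the actual call. A hedged sketch of the filtering idea (the real resolve_mm_processor_kwargs lives in vllm.utils; this only illustrates the behavior):

    import inspect
    from typing import Any, Dict

    def filter_for_call(init_kwargs: Dict[str, Any],
                        inference_kwargs: Dict[str, Any],
                        hf_processor) -> Dict[str, Any]:
        # keep only kwargs the processor call actually accepts;
        # inference-time values take precedence over init-time ones
        allowed = set(inspect.signature(hf_processor.__call__).parameters)
        merged = {**init_kwargs, **inference_kwargs}
        return {k: v for k, v in merged.items() if k in allowed}
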