[VLM] Enable tokenized inputs for merged multi-modal processor (#11900)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-01-10 11:24:00 +08:00 committed by GitHub
parent c3cf54dda4
commit b844b99ad3
12 changed files with 207 additions and 77 deletions

View File

@ -649,7 +649,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
)
def _test_processing_cache_correctness(
def _test_processing_correctness(
model_id: str,
modalities: dict[str, bool],
hit_rate: float,
@ -691,6 +691,7 @@ def _test_processing_cache_correctness(
baseline_processor = factories.build_processor(ctx, cache=None)
cached_processor = factories.build_processor(ctx, cache=cache)
dummy_inputs = baseline_processor.dummy_inputs
tokenizer = baseline_processor.info.get_tokenizer()
rng = np.random.RandomState(0)
@ -747,7 +748,25 @@ def _test_processing_cache_correctness(
)
assert baseline_result == cached_result, (
f"Failed ({batch_idx=}, {mm_data=})")
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
baseline_tokenized_result = baseline_processor.apply(
tokenizer.encode(prompt),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
assert baseline_result == baseline_tokenized_result, (
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
cached_tokenized_result = cached_processor.apply(
tokenizer.encode(prompt),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
assert cached_result == cached_tokenized_result, (
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
# yapf: disable
@ -771,14 +790,14 @@ def _test_processing_cache_correctness(
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_cache_correctness(
def test_processing_correctness(
model_id: str,
modalities: dict[str, bool],
hit_rate: float,
num_batches: int,
simplify_rate: float,
):
_test_processing_cache_correctness(
_test_processing_correctness(
model_id,
modalities,
hit_rate=hit_rate,
@ -795,7 +814,7 @@ def test_processing_cache_correctness(
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_cache_correctness_phi3v(
def test_processing_correctness_phi3v(
model_id: str,
modalities: dict[str, bool],
hit_rate: float,
@ -809,7 +828,7 @@ def test_processing_cache_correctness_phi3v(
AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
_test_processing_cache_correctness(
_test_processing_correctness(
model_id,
modalities,
hit_rate=hit_rate,

View File

@ -44,13 +44,13 @@ class TokensPrompt(TypedDict):
multi_modal_data: NotRequired["MultiModalDataDict"]
"""
DEPRECATED: Optional multi-modal data to pass to the model,
Optional multi-modal data to pass to the model,
if the model supports it.
"""
mm_processor_kwargs: NotRequired[Dict[str, Any]]
"""
DEPRECATED: Optional multi-modal processor kwargs to be forwarded to the
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
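
With the DEPRECATED notes removed, passing multi-modal data together with a pre-tokenized prompt is a supported input path for models served by a merged processor. A minimal usage sketch, assuming the `vllm.inputs` import path and a placeholder vision-language model id (the image object is assumed to be loaded elsewhere):

# Illustrative sketch (not part of the diff); the model id is a placeholder
# and `image` is assumed to be a PIL.Image loaded elsewhere.
from vllm import LLM
from vllm.inputs import TokensPrompt

llm = LLM(model="llava-hf/llava-1.5-7b-hf")  # placeholder model
tokenizer = llm.get_tokenizer()

prompt_ids = tokenizer.encode("USER: <image>\nWhat is shown? ASSISTANT:")
outputs = llm.generate(
    TokensPrompt(
        prompt_token_ids=prompt_ids,        # tokenized prompt instead of text
        multi_modal_data={"image": image},  # no longer deprecated
    ))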

View File

@ -279,10 +279,6 @@ class InputPreprocessor:
mm_processor = self.mm_registry.create_processor(
self.model_config, tokenizer)
if isinstance(prompt, list):
logger.warning("Passing `multi_modal_data` in TokensPrompt is"
"deprecated and will be removed in a future update")
prompt = tokenizer.decode(prompt)
if mm_processor_kwargs is None:
mm_processor_kwargs = {}

View File

@ -441,6 +441,24 @@ class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]):
class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data:
# HF processor always adds placeholders even when there's no image
tokenizer = self.info.get_tokenizer()
prompt_ids = tokenizer.encode(prompt)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
@ -469,11 +487,11 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
def apply(
self,
prompt_text: str,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
# Only <image> tokens should be considered as placeholders,
# so we ignore the trailing bos_token

View File

@ -99,6 +99,34 @@ class ChameleonDummyInputsBuilder(
class ChameleonMultiModalProcessor(
BaseMultiModalProcessor[ChameleonProcessingInfo]):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data:
prompt_ids = self.info.get_tokenizer().encode(prompt)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
) -> list[int]:
# HF processor adds sep token for chat mode
tokenizer = self.info.get_tokenizer()
sep_token_id: int = \
tokenizer.vocab[tokenizer.sep_token] # type: ignore
return prompt_tokens + [sep_token_id]
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
@ -128,11 +156,11 @@ class ChameleonMultiModalProcessor(
def apply(
self,
prompt_text: str,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
# Only <image> tokens should be considered as placeholders,
# so we ignore the image_start_token and image_end_token
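
The new _apply_hf_processor_tokens_only override keeps the token path consistent with the text path: Chameleon's HF processor appends a sep token when formatting text, so the same token is appended when the caller supplies token IDs directly. A small sketch of that consistency contract, assuming `processor` is a constructed ChameleonMultiModalProcessor and using an arbitrary example prompt:

# Illustrative sketch (not part of the diff).
tokenizer = processor.info.get_tokenizer()
prompt = "Describe the picture."  # example prompt

text_ids = processor._apply_hf_processor_text_only(prompt)
token_ids = processor._apply_hf_processor_tokens_only(tokenizer.encode(prompt))

# Both paths must produce the same IDs; the override appends the sep token
# that Chameleon's HF processor adds in chat mode.
assert text_ids == token_ids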

View File

@ -16,7 +16,7 @@
""" PyTorch Fuyu model."""
import math
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict)
TypedDict, Union)
import torch
import torch.nn as nn
@ -149,14 +149,10 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data:
# Avoid warning from HF logger for text-only input
# Input_ids format: bos_token_id + prompt_token_ids + boa_token_id
# Tokenizer won't add boa_token_id by default, we add it manually.
tokenizer = self.info.get_tokenizer()
boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore
prompt_ids = tokenizer.encode(prompt) + [boa_token_id]
prompt_ids = self.info.get_tokenizer().encode(prompt)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
processed_outputs = super()._call_hf_processor(
@ -181,6 +177,16 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
return processed_outputs
def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
) -> list[int]:
# HF processor adds boa_token_id
tokenizer = self.info.get_tokenizer()
boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore
return prompt_tokens + [boa_token_id]
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
@ -223,11 +229,11 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
def apply(
self,
prompt_text: str,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
# Only |SPEAKER| (image) tokens should be considered as placeholders,
# so we ignore the trailing bos_token_id

View File

@ -39,13 +39,13 @@ class SupportsMultiModal(Protocol):
The output embeddings must be one of the following formats:
- A list or tuple of 2D tensors, where each tensor corresponds to
each input multimodal data item (e.g, image).
- A single 3D tensor, with the batch dimension grouping the 2D tensors.
Note:
The returned multimodal embeddings must be in the same order as
the appearances of their corresponding multimodal data item in the
input prompt.
"""
...

View File

@ -724,7 +724,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
def apply(
self,
prompt_text: str,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
@ -737,7 +737,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
image_height=-1,
)
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
mm_items = self._to_mm_items(mm_data)
mm_item_counts = mm_items.get_all_counts()
@ -760,7 +760,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
)
])
prompt_ids, prompt_text, _ = self._apply_prompt_replacements(
prompt_ids, prompt, _ = self._apply_prompt_replacements(
result["prompt_token_ids"],
mantis_mm_repls,
mm_item_counts,
@ -788,7 +788,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
return MultiModalInputsV2(
type="multimodal",
prompt=prompt_text,
prompt=prompt,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_placeholders=mm_placeholder_ranges,

View File

@ -481,11 +481,11 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
def apply(
self,
prompt_text: str,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
# Only <|image|> tokens should be considered as placeholders,
# so we ignore the trailing bos_token_id

View File

@ -138,12 +138,8 @@ class UltravoxMultiModalProcessor(
) -> BatchFeature:
# Text-only input not supported in composite processor
if not mm_data:
tokenizer = self.info.get_tokenizer()
prompt_ids = tokenizer.encode(
prompt,
add_special_tokens=False, # type: ignore
)
prompt_ids = self.info.get_tokenizer().encode(prompt)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
mm_data = dict(mm_data)
@ -188,6 +184,16 @@ class UltravoxMultiModalProcessor(
)
return BatchFeature(combined_outputs)
def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
) -> list[int]:
# HF processor omits bos_token_id by setting add_special_tokens=False
tokenizer = self.info.get_tokenizer()
assert prompt_tokens[0] == tokenizer.bos_token_id
return prompt_tokens[1:]
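
The assertion above relies on the tokenizer adding bos_token_id by default, while Ultravox's HF processor tokenizes with add_special_tokens=False (see the removed code above); stripping the leading BOS keeps the token path aligned with the text path. A small sketch of that relationship, assuming `tokenizer` is the model's text tokenizer obtained via self.info.get_tokenizer():

# Illustrative sketch (not part of the diff).
text = "What is said in this recording?"

with_bos = tokenizer.encode(text)                              # default adds bos_token_id
without_bos = tokenizer.encode(text, add_special_tokens=False)  # what the HF processor does

assert with_bos[0] == tokenizer.bos_token_id
assert with_bos[1:] == without_bos  # matches the HF processor's text path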
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,

View File

@ -725,15 +725,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_kwargs,
)
def _apply_hf_processor(
def _apply_hf_processor_text_mm(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs]:
"""
Wrapper of :meth:`_call_hf_processor` that applies
additional pre-processing and post-processing.
Apply the HF processor on the prompt text and multi-modal data
together.
"""
processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
@ -753,40 +753,93 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return prompt_ids, mm_kwargs
def _apply_hf_processor_missing(
self,
prompt_text: str,
mm_missing_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
):
def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
"""
Apply the HF processor on the full prompt text, but only on the
multi-modal data that are missing from the cache.
Apply the HF processor on the prompt text only.
Note:
We pass prompt text and multi-modal data into the HF processor
in separate calls to avoid HF prompt replacement being done for
cached items; instead, we rely on our own prompt replacement logic
(:meth:`_get_prompt_replacements`) for the full text.
Since HF processor requires that text and multi-modal items
correspond to each other, we create dummy multi-modal items
to go along with the text.
"""
mm_missing_counts = mm_missing_data_items.get_all_counts()
prompt_ids, _ = self._apply_hf_processor(
prompt_ids, _ = self._apply_hf_processor_text_mm(
prompt_text=prompt_text,
mm_items=MultiModalDataItems({}),
hf_processor_mm_kwargs={},
)
# Some HF processors (e.g. Qwen2-VL) expect corresponding
# multi-modal tokens to be in the prompt text
return prompt_ids
def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
) -> list[int]:
"""
Apply the HF processor on the prompt tokens only.
Most HF processors accept prompt text but not prompt tokens.
If the HF processor adds or removes tokens that are not related to
multi-modal data, you should override this method so it is consistent
with the output of :meth:`_apply_hf_processor_text_only` on the
corresponding text.
"""
return prompt_tokens
def _apply_hf_processor_mm_only(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalKwargs:
"""
Apply the HF processor on the multi-modal data only.
Since HF processor requires that text and multi-modal items
correspond to each other, we generate dummy text using
:class:`DummyInputsBuilder` to go along with the multi-modal data.
"""
mm_counts = mm_items.get_all_counts()
dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
self.info.ctx.model_config.max_model_len,
mm_missing_counts,
mm_counts,
)
_, mm_missing_kwargs = self._apply_hf_processor(
_, mm_kwargs = self._apply_hf_processor_text_mm(
prompt_text=dummy_inputs.prompt_text,
mm_items=mm_missing_data_items,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
return mm_kwargs
def _apply_hf_processor_main(
self,
prompt: Union[str, list[int]],
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
*,
enable_hf_prompt_replacement: bool,
) -> tuple[list[int], MultiModalKwargs]:
"""
Apply the HF processor on the prompt text and multi-modal data.
Note:
If :code:`enable_hf_prompt_replacement=False`, the prompt should
correspond to the multi-modal items.
"""
if isinstance(prompt, str):
if enable_hf_prompt_replacement:
return self._apply_hf_processor_text_mm(
prompt_text=prompt,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
prompt_ids = self._apply_hf_processor_text_only(prompt)
else:
prompt_ids = self._apply_hf_processor_tokens_only(prompt)
mm_missing_kwargs = self._apply_hf_processor_mm_only(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
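
Taken together, _apply_hf_processor_main dispatches on the prompt type: a text prompt either goes through the HF processor together with the multi-modal data (when enable_hf_prompt_replacement is set) or is tokenized on its own, while a token prompt only passes through _apply_hf_processor_tokens_only; in the latter two cases the multi-modal kwargs come from _apply_hf_processor_mm_only using dummy text. A model whose HF processor inserts extra non-multimodal tokens is expected to override the tokens-only hook so both prompt forms stay in sync, following the Fuyu and Chameleon overrides above. A hedged sketch of such an override (the model name and the <eot> token are hypothetical, and the import path is assumed from this file):

# Illustrative sketch (not part of the diff); MyModelProcessingInfo and the
# <eot> token are hypothetical stand-ins for a model-specific processor.
from vllm.multimodal.processing import BaseMultiModalProcessor

class MyModelMultiModalProcessor(BaseMultiModalProcessor[MyModelProcessingInfo]):

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        # Mirror what the model's HF processor adds to plain text so that
        # this method stays consistent with _apply_hf_processor_text_only.
        tokenizer = self.info.get_tokenizer()
        eot_token_id: int = tokenizer.vocab["<eot>"]  # hypothetical extra token
        return prompt_tokens + [eot_token_id]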
@ -794,7 +847,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _cached_apply_hf_processor(
self,
prompt_text: str,
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs]:
@ -807,10 +860,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
_, passthrough_data = self._get_hf_mm_data(mm_data_items)
if cache is None or passthrough_data:
return self._apply_hf_processor(
prompt_text=prompt_text,
return self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_replacement=True,
)
mm_maybe_cached_kw_items = {
@ -832,10 +886,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
}
mm_missing_data_items = self._to_mm_items(mm_missing_data)
prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing(
prompt_text=prompt_text,
mm_missing_data_items=mm_missing_data_items,
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we need to pass `enable_hf_prompt_replacement=False`
prompt_ids, mm_missing_kwargs = self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_missing_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_replacement=False,
)
mm_missing_next_idx = {
@ -1018,7 +1075,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def apply(
self,
prompt_text: str,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
@ -1056,7 +1113,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_hashes = None
prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
prompt_text,
prompt,
mm_items,
hf_processor_mm_kwargs,
)
@ -1101,12 +1158,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
# there is no need for us to insert them
if all(len(repls) == 0 for repls in mm_missing_repls.items()):
tokenizer = self.info.get_tokenizer()
prompt_text = decode_tokens(tokenizer, prompt_ids)
prompt = decode_tokens(tokenizer, prompt_ids)
mm_placeholders = hf_mm_placeholders
else:
(
prompt_ids,
prompt_text,
prompt,
missing_mm_placeholders,
) = self._apply_prompt_replacements(
prompt_ids,
@ -1125,7 +1182,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return MultiModalInputsV2(
type="multimodal",
prompt=prompt_text,
prompt=prompt,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes,

View File

@ -137,7 +137,7 @@ class MultiModalProfiler(Generic[_I]):
seq_len, mm_counts)
return self.processor.apply(
prompt_text=processor_inputs.prompt_text,
prompt=processor_inputs.prompt_text,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)