[Bugfix] Fix various bugs in multi-modal processor (#12031)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-01-14 20:16:11 +08:00 committed by GitHub
parent ff39141a49
commit bb354e6b2d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 69 additions and 58 deletions

View File

@ -421,6 +421,8 @@ def test_find_replace_tokens(
"pattern_1": [32000, 32000],
"pattern_2": [],
"pattern_3": [1550, 918, 1550],
# Test different modalities having the same tokens (32000)
"pattern_4": [32000],
},
],
)
@ -438,6 +440,14 @@ def test_find_replace_tokens(
replacement=[32000, 32000],
),
],
"pattern_4": [
PlaceholderInfo(
modality="pattern_4",
item_idx=0,
start_idx=3,
replacement=[32000],
),
],
}
),
@ -466,6 +476,7 @@ def test_find_replace_tokens(
replacement=[1550, 918, 1550],
),
],
# No match for pattern_4 as it has lower priority than pattern_1
}
),
(
@ -485,6 +496,14 @@ def test_find_replace_tokens(
replacement=[32000, 32000],
),
],
"pattern_4": [
PlaceholderInfo(
modality="pattern_4",
item_idx=0,
start_idx=5,
replacement=[32000],
),
],
"pattern_3": [
PlaceholderInfo(
modality="pattern_3",

View File

@ -404,71 +404,60 @@ def replace_text_matches(
return "".join(texts)
def _iter_modality_placeholders(
prompt: list[int],
modality: str,
modality_repls: Sequence[BoundPromptReplacement],
modal_item_count: int,
) -> Iterable[PlaceholderInfo]:
if modal_item_count == 0:
return
prompt_len = len(prompt)
item_idx = 0
start_idx = 0
while start_idx < prompt_len:
found = False
for repl_info in modality_repls:
replacement = repl_info.get_replacement(item_idx)
repl_tokens = replacement.token_ids
repl_len = len(repl_tokens)
end_idx = start_idx + repl_len
if repl_len == 0 or end_idx > prompt_len:
continue
if prompt[start_idx:end_idx] == repl_tokens:
yield PlaceholderInfo(
modality=modality,
item_idx=item_idx,
start_idx=start_idx,
replacement=repl_tokens,
)
item_idx += 1
if item_idx >= modal_item_count:
return
# Exclude overlapping matches
start_idx = end_idx
found = True
break
if not found:
start_idx += 1
def _iter_placeholders(
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
prompt: list[int],
mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderInfo]:
"""
For each modality, yield each set of placeholder tokens found in
:code:`prompt`.
Yield each set of placeholder tokens found in :code:`prompt`.
Matches are exclusive even when multiple modalities share
the same placeholder tokens. In that case, the modality that
appears earlier in `mm_prompt_repls` takes priority.
Note that empty matches are ignored.
"""
for modality, modal_item_count in mm_item_counts.items():
if modality in mm_prompt_repls:
yield from _iter_modality_placeholders(
prompt,
modality,
mm_prompt_repls[modality],
modal_item_count,
)
prompt_len = len(prompt)
item_idx_by_modality = defaultdict[str, int](lambda: 0)
start_idx = 0
while start_idx < prompt_len:
found = False
for modality, modality_repls in mm_prompt_repls.items():
item_idx = item_idx_by_modality[modality]
if item_idx >= mm_item_counts.get(modality, 0):
continue
for repl_info in modality_repls:
replacement = repl_info.get_replacement(item_idx)
repl_tokens = replacement.token_ids
repl_len = len(repl_tokens)
end_idx = start_idx + repl_len
if repl_len == 0 or end_idx > prompt_len:
continue
if prompt[start_idx:end_idx] == repl_tokens:
yield PlaceholderInfo(
modality=modality,
item_idx=item_idx,
start_idx=start_idx,
replacement=repl_tokens,
)
# Exclude overlapping matches
start_idx = end_idx
item_idx_by_modality[modality] += 1
found = True
break
if found:
break # Go back to the outer while loop
if not found:
start_idx += 1
def find_mm_placeholders(
@ -1156,7 +1145,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
if all(len(repls) == 0 for repls in mm_missing_repls.items()):
if all(len(repls) == 0 for repls in mm_missing_repls.values()):
tokenizer = self.info.get_tokenizer()
prompt = decode_tokens(tokenizer, prompt_ids)
mm_placeholders = hf_mm_placeholders

View File

@ -259,7 +259,10 @@ class MultiModalRegistry:
This is currently directly used only in V1.
"""
if self.has_processor(model_config):
tokenizer = cached_get_tokenizer(model_config.tokenizer)
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
processor = self.create_processor(model_config, tokenizer)
seq_len = model_config.max_model_len
return processor.info.get_mm_max_tokens_per_item(seq_len)