[Bugfix] Fix various bugs in multi-modal processor (#12031)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
ff39141a49
commit
bb354e6b2d
@ -421,6 +421,8 @@ def test_find_replace_tokens(
|
|||||||
"pattern_1": [32000, 32000],
|
"pattern_1": [32000, 32000],
|
||||||
"pattern_2": [],
|
"pattern_2": [],
|
||||||
"pattern_3": [1550, 918, 1550],
|
"pattern_3": [1550, 918, 1550],
|
||||||
|
# Test different modalities having the same tokens (32000)
|
||||||
|
"pattern_4": [32000],
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -438,6 +440,14 @@ def test_find_replace_tokens(
|
|||||||
replacement=[32000, 32000],
|
replacement=[32000, 32000],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
|
"pattern_4": [
|
||||||
|
PlaceholderInfo(
|
||||||
|
modality="pattern_4",
|
||||||
|
item_idx=0,
|
||||||
|
start_idx=3,
|
||||||
|
replacement=[32000],
|
||||||
|
),
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
),
|
),
|
||||||
@ -466,6 +476,7 @@ def test_find_replace_tokens(
|
|||||||
replacement=[1550, 918, 1550],
|
replacement=[1550, 918, 1550],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
|
# No match for pattern_4 as it has lower priority than pattern_1
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
@ -485,6 +496,14 @@ def test_find_replace_tokens(
|
|||||||
replacement=[32000, 32000],
|
replacement=[32000, 32000],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
|
"pattern_4": [
|
||||||
|
PlaceholderInfo(
|
||||||
|
modality="pattern_4",
|
||||||
|
item_idx=0,
|
||||||
|
start_idx=5,
|
||||||
|
replacement=[32000],
|
||||||
|
),
|
||||||
|
],
|
||||||
"pattern_3": [
|
"pattern_3": [
|
||||||
PlaceholderInfo(
|
PlaceholderInfo(
|
||||||
modality="pattern_3",
|
modality="pattern_3",
|
||||||
|
@ -404,71 +404,60 @@ def replace_text_matches(
|
|||||||
return "".join(texts)
|
return "".join(texts)
|
||||||
|
|
||||||
|
|
||||||
def _iter_modality_placeholders(
|
|
||||||
prompt: list[int],
|
|
||||||
modality: str,
|
|
||||||
modality_repls: Sequence[BoundPromptReplacement],
|
|
||||||
modal_item_count: int,
|
|
||||||
) -> Iterable[PlaceholderInfo]:
|
|
||||||
if modal_item_count == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
prompt_len = len(prompt)
|
|
||||||
item_idx = 0
|
|
||||||
|
|
||||||
start_idx = 0
|
|
||||||
while start_idx < prompt_len:
|
|
||||||
found = False
|
|
||||||
|
|
||||||
for repl_info in modality_repls:
|
|
||||||
replacement = repl_info.get_replacement(item_idx)
|
|
||||||
repl_tokens = replacement.token_ids
|
|
||||||
repl_len = len(repl_tokens)
|
|
||||||
end_idx = start_idx + repl_len
|
|
||||||
|
|
||||||
if repl_len == 0 or end_idx > prompt_len:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if prompt[start_idx:end_idx] == repl_tokens:
|
|
||||||
yield PlaceholderInfo(
|
|
||||||
modality=modality,
|
|
||||||
item_idx=item_idx,
|
|
||||||
start_idx=start_idx,
|
|
||||||
replacement=repl_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
item_idx += 1
|
|
||||||
if item_idx >= modal_item_count:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Exclude overlapping matches
|
|
||||||
start_idx = end_idx
|
|
||||||
found = True
|
|
||||||
break
|
|
||||||
|
|
||||||
if not found:
|
|
||||||
start_idx += 1
|
|
||||||
|
|
||||||
|
|
||||||
def _iter_placeholders(
|
def _iter_placeholders(
|
||||||
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
|
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
|
||||||
prompt: list[int],
|
prompt: list[int],
|
||||||
mm_item_counts: Mapping[str, int],
|
mm_item_counts: Mapping[str, int],
|
||||||
) -> Iterable[PlaceholderInfo]:
|
) -> Iterable[PlaceholderInfo]:
|
||||||
"""
|
"""
|
||||||
For each modality, yield each set of placeholder tokens found in
|
Yield each set of placeholder tokens found in :code:`prompt`.
|
||||||
:code:`prompt`.
|
|
||||||
|
Matches are exclusive even when multiple modalities share
|
||||||
|
the same placeholder tokens. In that case, the modality that
|
||||||
|
appears earlier in `mm_prompt_repls` takes priority.
|
||||||
|
|
||||||
Note that empty matches are ignored.
|
Note that empty matches are ignored.
|
||||||
"""
|
"""
|
||||||
for modality, modal_item_count in mm_item_counts.items():
|
prompt_len = len(prompt)
|
||||||
if modality in mm_prompt_repls:
|
item_idx_by_modality = defaultdict[str, int](lambda: 0)
|
||||||
yield from _iter_modality_placeholders(
|
|
||||||
prompt,
|
start_idx = 0
|
||||||
modality,
|
while start_idx < prompt_len:
|
||||||
mm_prompt_repls[modality],
|
found = False
|
||||||
modal_item_count,
|
|
||||||
)
|
for modality, modality_repls in mm_prompt_repls.items():
|
||||||
|
item_idx = item_idx_by_modality[modality]
|
||||||
|
if item_idx >= mm_item_counts.get(modality, 0):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for repl_info in modality_repls:
|
||||||
|
replacement = repl_info.get_replacement(item_idx)
|
||||||
|
repl_tokens = replacement.token_ids
|
||||||
|
repl_len = len(repl_tokens)
|
||||||
|
end_idx = start_idx + repl_len
|
||||||
|
|
||||||
|
if repl_len == 0 or end_idx > prompt_len:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if prompt[start_idx:end_idx] == repl_tokens:
|
||||||
|
yield PlaceholderInfo(
|
||||||
|
modality=modality,
|
||||||
|
item_idx=item_idx,
|
||||||
|
start_idx=start_idx,
|
||||||
|
replacement=repl_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Exclude overlapping matches
|
||||||
|
start_idx = end_idx
|
||||||
|
item_idx_by_modality[modality] += 1
|
||||||
|
found = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if found:
|
||||||
|
break # Go back to the outer while loop
|
||||||
|
|
||||||
|
if not found:
|
||||||
|
start_idx += 1
|
||||||
|
|
||||||
|
|
||||||
def find_mm_placeholders(
|
def find_mm_placeholders(
|
||||||
@ -1156,7 +1145,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
|||||||
|
|
||||||
# If HF processor already inserts placeholder tokens,
|
# If HF processor already inserts placeholder tokens,
|
||||||
# there is no need for us to insert them
|
# there is no need for us to insert them
|
||||||
if all(len(repls) == 0 for repls in mm_missing_repls.items()):
|
if all(len(repls) == 0 for repls in mm_missing_repls.values()):
|
||||||
tokenizer = self.info.get_tokenizer()
|
tokenizer = self.info.get_tokenizer()
|
||||||
prompt = decode_tokens(tokenizer, prompt_ids)
|
prompt = decode_tokens(tokenizer, prompt_ids)
|
||||||
mm_placeholders = hf_mm_placeholders
|
mm_placeholders = hf_mm_placeholders
|
||||||
|
@ -259,7 +259,10 @@ class MultiModalRegistry:
|
|||||||
This is currently directly used only in V1.
|
This is currently directly used only in V1.
|
||||||
"""
|
"""
|
||||||
if self.has_processor(model_config):
|
if self.has_processor(model_config):
|
||||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
tokenizer = cached_get_tokenizer(
|
||||||
|
model_config.tokenizer,
|
||||||
|
trust_remote_code=model_config.trust_remote_code,
|
||||||
|
)
|
||||||
processor = self.create_processor(model_config, tokenizer)
|
processor = self.create_processor(model_config, tokenizer)
|
||||||
seq_len = model_config.max_model_len
|
seq_len = model_config.max_model_len
|
||||||
return processor.info.get_mm_max_tokens_per_item(seq_len)
|
return processor.info.get_mm_max_tokens_per_item(seq_len)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user