[Core] Don't use cache during multi-modal profiling (#14336)

commit 82551ad616
parent caac5c2e59
Author: Cyrus Leung
Date:   2025-03-07 00:03:31 +08:00 (committed via GitHub)
2 changed files with 15 additions and 5 deletions
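
In short: the three profiling-related call sites (the dummy-data profiler in InputRegistry, and the max-tokens-per-item and multi-modal limit helpers in MultiModalRegistry) now pass disable_cache=True when building a multi-modal processor, and MultiModalRegistry.create_processor gains a keyword-only disable_cache parameter that falls back to model_config.disable_mm_preprocessor_cache when left unset. Profiling runs once over dummy inputs, so caching their processed outputs would presumably only fill the processing cache with entries that are never reused. Judging by the class names in the hunk headers, the two files touched are most likely vllm/inputs/registry.py and vllm/multimodal/registry.py.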


@@ -331,7 +331,9 @@ class InputRegistry:
         if mm_registry.has_processor(model_config):
             tokenizer = cached_tokenizer_from_config(model_config)
-            processor = mm_registry.create_processor(model_config, tokenizer)
+            processor = mm_registry.create_processor(model_config,
+                                                     tokenizer,
+                                                     disable_cache=True)
             profiler = MultiModalProfiler(processor)
             dummy_data = profiler.get_dummy_data(
                 seq_len, is_encoder_data=is_encoder_data)
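
The same disable_cache=True override is applied at the two profiling-related call sites in MultiModalRegistry below; the last two hunks then thread the new flag through create_processor itself.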


@@ -257,7 +257,9 @@ class MultiModalRegistry:
         """
         if self.has_processor(model_config):
             tokenizer = cached_tokenizer_from_config(model_config)
-            processor = self.create_processor(model_config, tokenizer)
+            processor = self.create_processor(model_config,
+                                              tokenizer,
+                                              disable_cache=True)
             seq_len = model_config.max_model_len
             mm_limits = self.get_mm_limits_per_prompt(model_config)
             return processor.info.get_mm_max_tokens_per_item(
@@ -372,7 +374,9 @@ class MultiModalRegistry:
         """
         if self.has_processor(model_config):
             tokenizer = cached_tokenizer_from_config(model_config)
-            processor = self.create_processor(model_config, tokenizer)
+            processor = self.create_processor(model_config,
+                                              tokenizer,
+                                              disable_cache=True)
             profiler = MultiModalProfiler(processor)
             return profiler.get_mm_limits()
@@ -433,6 +437,8 @@ class MultiModalRegistry:
         self,
         model_config: "ModelConfig",
         tokenizer: AnyTokenizer,
+        *,
+        disable_cache: Optional[bool] = None,
     ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
         """
         Create a multi-modal processor for a specific model and tokenizer.
@@ -440,11 +446,13 @@ class MultiModalRegistry:
 
         See also:
             :ref:`mm-processing`
         """
+        if disable_cache is None:
+            disable_cache = model_config.disable_mm_preprocessor_cache
+
         model_cls = self._get_model_cls(model_config)
         factories = self._processor_factories[model_cls]
 
         ctx = InputProcessingContext(model_config, tokenizer)
-        cache = (None if model_config.disable_mm_preprocessor_cache else
-                 self._processing_cache)
+        cache = None if disable_cache else self._processing_cache
         return factories.build_processor(ctx, cache=cache)
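
The new signature follows a common tri-state pattern: a keyword-only Optional[bool] defaulting to None means "defer to the config-level setting", while call sites such as the profiler can force an explicit value. Below is a minimal, self-contained sketch of that pattern; AppConfig and the dict standing in for the cache are hypothetical illustrations, not vLLM's actual classes.

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class AppConfig:
        # Global default, analogous to model_config.disable_mm_preprocessor_cache.
        disable_cache: bool = False

    def create_processor(config: AppConfig, *,
                         disable_cache: Optional[bool] = None) -> Optional[dict]:
        # Tri-state override: None defers to the config default, while an
        # explicit True/False from the caller always wins.
        if disable_cache is None:
            disable_cache = config.disable_cache
        # Stand-in for the processing cache that the real code hands to the
        # processor factory (None disables caching entirely).
        return None if disable_cache else {}

    config = AppConfig()
    assert create_processor(config) == {}                        # normal path: cache on
    assert create_processor(config, disable_cache=True) is None  # profiling path: cache off

The bare * in the parameter list makes disable_cache keyword-only, so existing positional callers keep working unchanged and every override is explicit at the call site.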