[Model] Update multi-modal processor to support Mantis(LLaVA) model (#10711)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2024-12-08 01:10:05 +08:00 committed by GitHub
parent 1c768fe537
commit 39e227c7ae
14 changed files with 175 additions and 78 deletions

View File

@ -362,6 +362,7 @@ steps:
- tests/models/embedding/vision_language
- tests/models/encoder_decoder/vision_language
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
- pytest -v -s models/embedding/vision_language -m core_model
@ -377,6 +378,7 @@ steps:
- tests/models/embedding/vision_language
- tests/models/encoder_decoder/vision_language
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
# HACK - run phi3v tests separately to sidestep this transformers bug
# https://github.com/huggingface/transformers/issues/34307

View File

@ -555,7 +555,7 @@ Text Generation
* - :code:`LlavaForConditionalGeneration`
- LLaVA-1.5
- T + I\ :sup:`E+`
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc.
-
- ✅︎
* - :code:`LlavaNextForConditionalGeneration`
@ -664,6 +664,10 @@ Text Generation
.. note::
vLLM currently only supports adding LoRA to the language backbone of multimodal models.
.. note::
To use :code:`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to install their GitHub repo (:code:`pip install git+https://github.com/TIGER-AI-Lab/Mantis.git`)
and pass :code:`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
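For example, a minimal offline-inference sketch of that override (mirroring the example added elsewhere in this commit):

    from vllm import LLM

    llm = LLM(
        model="TIGER-Lab/Mantis-8B-siglip-llama3",
        max_model_len=4096,
        hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
    )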
.. note::
The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630

View File

@ -419,6 +419,22 @@ def run_aria(question: str, modality: str):
return llm, prompt, stop_token_ids
# Mantis
def run_mantis(question: str, modality: str):
assert modality == "image"
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa: E501
prompt = llama3_template.format(f"{question}\n<image>")
llm = LLM(
model="TIGER-Lab/Mantis-8B-siglip-llama3",
max_model_len=4096,
hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
)
stop_token_ids = [128009]
return llm, prompt, stop_token_ids
model_example_map = {
"llava": run_llava,
"llava-next": run_llava_next,
@ -441,6 +457,7 @@ model_example_map = {
"glm4v": run_glm4v,
"idefics3": run_idefics3,
"aria": run_aria,
"mantis": run_mantis,
}
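For reference, the example script drives these entries roughly as follows (a sketch, assuming a PIL image has already been loaded into `image`):

    from vllm import SamplingParams

    llm, prompt, stop_token_ids = model_example_map["mantis"]("What is in the image?", "image")
    sampling_params = SamplingParams(temperature=0.2, max_tokens=64, stop_token_ids=stop_token_ids)
    outputs = llm.generate(
        {"prompt": prompt, "multi_modal_data": {"image": image}},
        sampling_params=sampling_params,
    )
    print(outputs[0].outputs[0].text)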

View File

@ -24,9 +24,6 @@ mistral_common[opencv] >= 1.5.0 # required for pixtral test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.4 # required for model evaluation test
# TODO: Add this after fully implementing llava(mantis)
# git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test
# quantization
bitsandbytes>=0.44.0
buildkite-test-collector==0.1.9

View File

@ -34,7 +34,7 @@ COMMON_BROADCAST_SETTINGS = {
"dtype": "half",
"max_tokens": 5,
"tensor_parallel_size": 2,
"model_kwargs": {"device_map": "auto"},
"hf_model_kwargs": {"device_map": "auto"},
"image_size_factors": [(.25, 0.5, 1.0)],
"distributed_executor_backend": (
"ray",
@ -108,7 +108,7 @@ VLM_TEST_SETTINGS = {
"cherry_blossom": "What is in the picture?",
}),
auto_cls=AutoModelForVision2Seq,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
@ -151,7 +151,7 @@ VLM_TEST_SETTINGS = {
"cherry_blossom": "<vlm_image>Please infer the season with reason.",
}),
multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", # noqa: E501
postprocess_inputs=model_utils.get_key_type_post_processor("pixel_values"),
postprocess_inputs=model_utils.cast_dtype_post_processor("pixel_values"),
stop_str=["<|im_end|>"],
image_size_factors=[(0.10, 0.15)],
max_tokens=64,
@ -177,7 +177,7 @@ VLM_TEST_SETTINGS = {
prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
max_model_len=4096,
auto_cls=AutoModelForVision2Seq,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
# For chameleon, we only compare the sequences
@ -281,7 +281,7 @@ VLM_TEST_SETTINGS = {
prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
num_video_frames=16,
max_model_len=16384,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values_videos"
),
auto_cls=AutoModelForVision2Seq,
@ -306,6 +306,20 @@ VLM_TEST_SETTINGS = {
vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output,
image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
),
"mantis": VLMTestInfo(
models=["TIGER-Lab/Mantis-8B-siglip-llama3"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
max_model_len=4096,
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
vllm_runner_kwargs={"hf_overrides": {"architectures": ["MantisForConditionalGeneration"]}}, # noqa: E501
get_stop_token_ids=lambda tok: [128009],
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.mantis_vllm_to_hf_output,
patch_hf_runner=model_utils.mantis_patch_hf_runner,
),
"minicpmv_25": VLMTestInfo(
models=["openbmb/MiniCPM-Llama3-V-2_5"],
test_type=VLMTestType.IMAGE,
@ -342,7 +356,7 @@ VLM_TEST_SETTINGS = {
# max_num_seqs=2,
# task="generate",
# # use eager mode for hf runner since phi3v didn't work with flash_attn
# model_kwargs={"_attn_implementation": "eager"},
# hf_model_kwargs={"_attn_implementation": "eager"},
# use_tokenizer_eos=True,
# vllm_output_post_proc=model_utils.phi3v_vllm_to_hf_output,
# num_logprobs=10,
@ -373,7 +387,7 @@ VLM_TEST_SETTINGS = {
prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
max_model_len=4096,
auto_cls=AutoModelForVision2Seq,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2],
@ -438,7 +452,7 @@ VLM_TEST_SETTINGS = {
test_type=VLMTestType.CUSTOM_INPUTS,
max_model_len=16384,
max_num_seqs=2,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
auto_cls=AutoModelForVision2Seq,

View File

@ -3,9 +3,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import torch
from PIL.Image import Image
from transformers import AutoTokenizer, BatchEncoding
from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm.config import TaskOption
from .....conftest import HfRunner, VllmRunner
from .types import RunnerOutput
@ -28,13 +30,15 @@ def run_test(
use_tokenizer_eos: bool,
postprocess_inputs: Callable[[BatchEncoding], BatchEncoding],
comparator: Callable[..., None],
get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]],
get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase],
List[int]]],
stop_str: Optional[List[str]],
tokenizer_mode: str,
limit_mm_per_prompt: Dict[str, int],
model_kwargs: Optional[Dict[str, Any]],
vllm_runner_kwargs: Optional[Dict[str, Any]],
hf_model_kwargs: Optional[Dict[str, Any]],
patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]],
task: str = "auto",
task: TaskOption = "auto",
runner_mm_key: str = "images",
distributed_executor_backend: Optional[str] = None,
tensor_parallel_size: int = 1,
@ -58,6 +62,9 @@ def run_test(
if stop_str:
vllm_kwargs["stop"] = stop_str
if vllm_runner_kwargs is None:
vllm_runner_kwargs = {}
with vllm_runner(model,
tokenizer_mode=tokenizer_mode,
max_model_len=max_model_len,
@ -67,7 +74,8 @@ def run_test(
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=enforce_eager,
task=task) as vllm_model:
task=task,
**vllm_runner_kwargs) as vllm_model:
for prompts, media in vllm_inputs:
vllm_kwargs[runner_mm_key] = media
vllm_output = vllm_model.generate_greedy_logprobs(
@ -78,7 +86,7 @@ def run_test(
dtype=dtype,
auto_cls=auto_cls,
postprocess_inputs=postprocess_inputs,
model_kwargs=model_kwargs)
model_kwargs=hf_model_kwargs)
# Some models need to patch things like the model processor, e.g., internvl
if patch_hf_runner is not None:

View File

@ -126,6 +126,16 @@ def llava_onevision_vllm_to_hf_output(vllm_output: RunnerOutput,
return hf_output_ids, hf_output_str, out_logprobs
def mantis_vllm_to_hf_output(vllm_output: RunnerOutput,
model: str) -> RunnerOutput:
"""Sanitize vllm output [mantis] to compare with hf output."""
output_ids, output_str, out_logprobs = vllm_output
hf_output_str = output_str + "<|eot_id|>"
return output_ids, hf_output_str, out_logprobs
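A quick illustration of the sanitizer, using hypothetical values:

    ids, text, logprobs = mantis_vllm_to_hf_output(
        ([9906, 13], "A cat.", None), "TIGER-Lab/Mantis-8B-siglip-llama3")
    assert text == "A cat.<|eot_id|>"  # the HF reference output keeps <|eot_id|>, so it is appended here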
def phi3v_vllm_to_hf_output(vllm_output: RunnerOutput,
model: str) -> RunnerOutput:
"""Sanitize vllm output [phi3v] to be comparable with hf output."""
@ -184,7 +194,7 @@ def get_llava_embeddings(image_assets: _ImageAssets):
####### postprocessors to run on HF BatchEncoding
def get_key_type_post_processor(
def cast_dtype_post_processor(
hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]:
"""Gets a handle to a post processor which converts a given key into a
target data type."""
@ -418,3 +428,26 @@ def _internvl_generate(
)
return outputs
def mantis_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
from mantis.models.mllava import MLlavaProcessor
hf_model.processor = MLlavaProcessor.from_pretrained(hf_model.model_name)
orig_generate = hf_model.model.generate
tokenizer = hf_model.processor.tokenizer
def _generate(self, *args, **kwargs):
return orig_generate(
*args,
**kwargs,
eos_token_id=[
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|eot_id|>"),
],
)
hf_model.model.generate = types.MethodType(_generate, hf_model.model)
return hf_model
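The <|eot_id|> token looked up above is the same Llama-3 token that the test settings hard-code as stop ID 128009; a quick sanity check, assuming the Mantis tokenizer can be downloaded:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3")
    assert tok.convert_tokens_to_ids("<|eot_id|>") == 128009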

View File

@ -7,9 +7,11 @@ from typing import (Any, Callable, Dict, Iterable, List, NamedTuple, Optional,
import torch
from PIL.Image import Image
from pytest import MarkDecorator
from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding
from transformers import (AutoModelForCausalLM, BatchEncoding,
PreTrainedTokenizerBase)
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm.config import TaskOption
from vllm.sequence import SampleLogprobs
from vllm.utils import identity
@ -66,7 +68,7 @@ class ImageSizeWrapper(NamedTuple):
class VLMTestInfo(NamedTuple):
"""Holds the configuration for 1+ tests for one model architecture."""
models: Union[List[str]]
models: List[str]
test_type: Union[VLMTestType, Iterable[VLMTestType]]
# Should be None only if this is a CUSTOM_INPUTS test
@ -92,18 +94,20 @@ class VLMTestInfo(NamedTuple):
enforce_eager: bool = True
max_model_len: int = 1024
max_num_seqs: int = 256
task: str = "auto"
task: TaskOption = "auto"
tensor_parallel_size: int = 1
vllm_runner_kwargs: Optional[Dict[str, Any]] = None
# Optional callable which gets a list of token IDs from the model tokenizer
get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]] = None
get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase],
List[int]]] = None
# Optional list of strings to stop generation, useful when stop tokens are
# not special tokens in the tokenizer
stop_str: Optional[List[str]] = None
# Exposed options for HF runner
model_kwargs: Optional[Dict[str, Any]] = None
# Indicates we should explicitly pass the EOS from the tokeniezr
hf_model_kwargs: Optional[Dict[str, Any]] = None
# Indicates we should explicitly pass the EOS from the tokenizer
use_tokenizer_eos: bool = False
auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM
# Callable to pass to the HF runner to run on inputs; for now, we also pass
@ -164,6 +168,7 @@ class VLMTestInfo(NamedTuple):
"max_num_seqs": self.max_num_seqs,
"task": self.task,
"tensor_parallel_size": self.tensor_parallel_size,
"vllm_runner_kwargs": self.vllm_runner_kwargs,
"hf_output_post_proc": self.hf_output_post_proc,
"vllm_output_post_proc": self.vllm_output_post_proc,
"auto_cls": self.auto_cls,
@ -171,8 +176,8 @@ class VLMTestInfo(NamedTuple):
"postprocess_inputs": self.postprocess_inputs,
"comparator": self.comparator,
"get_stop_token_ids": self.get_stop_token_ids,
"hf_model_kwargs": self.hf_model_kwargs,
"stop_str": self.stop_str,
"model_kwargs": self.model_kwargs,
"patch_hf_runner": self.patch_hf_runner,
"tokenizer_mode": self.tokenizer_mode
}

View File

@ -176,6 +176,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501
"LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501
"LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501
"MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3"), # noqa: E501
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
trust_remote_code=True),
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",

View File

@ -3,16 +3,14 @@ from typing import Optional
import torch
from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
create_metadata_for_llava,
dummy_mm_kwargs_for_llava,
LlavaProcessor,
get_max_llava_image_tokens)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor_by_metadata(create_metadata_for_llava,
dummy_mm_kwargs_for_llava)
@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor)
class MyLlava(LlavaForConditionalGeneration):
def compute_logits(
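For context, an out-of-tree class like MyLlava is exposed to vLLM through the model registry; a hedged sketch of the registration call (placement is an assumption; the dummy-model plugin performs the equivalent registration from its plugin entry point):

    from vllm import ModelRegistry

    ModelRegistry.register_model("MyLlava", MyLlava)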

View File

@ -22,10 +22,11 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.processing import (InputProcessingContext,
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext,
ModalityProcessingMetadata,
MultiModalProcessingMetadata,
MultiModalProcessor, PromptReplacement)
PromptReplacement)
from vllm.sequence import IntermediateTensors
from .clip import (CLIPVisionModel, dummy_image_for_clip,
@ -163,7 +164,13 @@ def create_metadata_for_llava(
}
class LlavaProcessor(MultiModalProcessor):
class LlavaProcessor(BaseMultiModalProcessor):
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__(
ctx=ctx,
metadata=create_metadata_for_llava(ctx),
)
def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
if getattr(hf_processor, "__is_patched__", False):
@ -193,7 +200,30 @@ class LlavaProcessor(MultiModalProcessor):
self,
mm_counts: Mapping[str, int],
) -> MultiModalKwargs:
return dummy_mm_kwargs_for_llava(self.ctx, mm_counts)
hf_config = self.ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
if isinstance(vision_config, CLIPVisionConfig):
data = dummy_image_for_clip(vision_config, num_images)
elif isinstance(vision_config, SiglipVisionConfig):
data = dummy_image_for_siglip(vision_config, num_images)
elif isinstance(vision_config, PixtralVisionConfig):
data = dummy_image_for_pixtral_hf(vision_config, num_images)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
hf_processor = self._get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
hf_inputs = image_processor.preprocess(data['image'],
return_tensors="pt")
is_pixtral = isinstance(hf_processor, PixtralProcessor)
return MultiModalKwargs(
**hf_inputs,
is_pixtral=torch.tensor(is_pixtral),
)
class LlavaLikeConfig(Protocol):
@ -277,10 +307,7 @@ def init_vision_tower_for_llava(
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(lambda ctx: LlavaProcessor(
ctx=ctx,
metadata=create_metadata_for_llava(ctx),
))
@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
@ -559,3 +586,28 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
class MantisProcessor(LlavaProcessor):
def _get_hf_processor(self) -> ProcessorMixin:
try:
from mantis.models.mllava import MLlavaProcessor
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"You need to `pip install "
"git+https://github.com/TIGER-AI-Lab/Mantis.git` "
"to use this model") from exc
processor = MLlavaProcessor.from_pretrained(
self.ctx.model_config.tokenizer)
assert isinstance(processor, ProcessorMixin)
return processor
# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(MantisProcessor)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
pass
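Why the override is needed: the Mantis checkpoint's own HF config does not advertise the Mantis architecture, so vLLM would otherwise route it to the plain LLaVA implementation. A quick inspection sketch (the printed value is an assumption, not verified here):

    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3")
    print(cfg.architectures)  # likely ['LlavaForConditionalGeneration'], hence the override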

View File

@ -152,6 +152,7 @@ _MULTIMODAL_MODELS = {
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501
"LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501
"MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501
"MiniCPMV": ("minicpmv", "MiniCPMV"),
"MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
"NVLM_D": ("nvlm_d", "NVLM_D_Model"),

View File

@ -529,9 +529,9 @@ def iter_placeholders(
yield placeholder
class MultiModalProcessor(ABC):
class BaseMultiModalProcessor(ABC):
"""
Helper class to process multi-modal inputs to be used in vLLM.
Abstract base class to process multi-modal inputs to be used in vLLM.
"""
def __init__(

View File

@ -15,7 +15,7 @@ from .audio import AudioPlugin
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
from .image import ImagePlugin
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
from .processing import MultiModalProcessingMetadata, MultiModalProcessor
from .processing import BaseMultiModalProcessor
from .video import VideoPlugin
if TYPE_CHECKING:
@ -26,7 +26,7 @@ logger = init_logger(__name__)
N = TypeVar("N", bound=Type[nn.Module])
MultiModalProcessorFactory: TypeAlias = Callable[[InputProcessingContext],
MultiModalProcessor]
BaseMultiModalProcessor]
"""
Constructs a :class:`MultiModalProcessor` instance from the context.
@ -311,41 +311,6 @@ class MultiModalRegistry:
return wrapper
def register_processor_by_metadata(
self,
metadata_factory: Callable[[InputProcessingContext],
MultiModalProcessingMetadata],
get_dummy_mm_kwargs: Callable[
[InputProcessingContext, Mapping[str, int]], MultiModalKwargs],
):
"""
Convenience method to register a multi-modal processor to a model class
according to a function that constructs its metadata.
When the model receives multi-modal data, the provided function is
invoked to transform the data into a dictionary of model inputs.
See also:
- :ref:`input_processing_pipeline`
- :ref:`enabling_multimodal_inputs`
"""
class ConcreteMultiModalProcessor(MultiModalProcessor):
def _get_dummy_mm_kwargs(
self,
mm_counts: Mapping[str, int],
) -> MultiModalKwargs:
return get_dummy_mm_kwargs(self.ctx, mm_counts)
def factory(ctx: InputProcessingContext):
return ConcreteMultiModalProcessor(
ctx=ctx,
metadata=metadata_factory(ctx),
)
return self.register_processor(factory)
def has_processor(self, model_config: "ModelConfig") -> bool:
"""
Test whether a multi-modal processor is defined for a specific model.
@ -360,7 +325,7 @@ class MultiModalRegistry:
self,
model_config: "ModelConfig",
tokenizer: AnyTokenizer,
) -> MultiModalProcessor:
) -> BaseMultiModalProcessor:
"""
Create a multi-modal processor for a specific model and tokenizer.
"""