[Model] Add smolvlm support (#16017)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Parent: 1f4b09b525
Commit: 102bf967f0
```diff
@@ -990,6 +990,13 @@ See [this page](#generative-models) for more information on how to use generativ
   *
   * ✅︎
   * ✅︎
+- * `SmolVLMForConditionalGeneration`
+  * SmolVLM2
+  * T + I
+  * `SmolVLM2-2.2B-Instruct`
+  *
+  * ✅︎
+  * ✅︎
 - * `UltravoxModel`
   * Ultravox
   * T + A<sup>E+</sup>
```
```diff
@@ -298,6 +298,34 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
     )
 
 
+# SmolVLM2-2.2B-Instruct
+def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
+    assert modality == "image"
+    model_name = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=8192,
+        max_num_seqs=2,
+        enforce_eager=True,
+        mm_processor_kwargs={
+            "max_image_size": {
+                "longest_edge": 384
+            },
+        },
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+    )
+    prompts = [
+        (f"<|im_start|>User:<image>{question}<end_of_utterance>\nAssistant:")
+        for question in questions
+    ]
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
+
+
 # InternVL
 def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"
```
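For reference, a minimal sketch (not part of this diff) of how the example script's driver consumes a runner like `run_smolvlm`; the driver shape follows vLLM's offline API, and the `stop_sign` demo asset plus the sampling settings are assumptions:

```python
# Hypothetical driver snippet: assumes run_smolvlm() from the hunk above is in
# scope (inside the example script, where `args` is defined) and that vLLM's
# bundled "stop_sign" demo image asset is available.
from dataclasses import asdict

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

req_data = run_smolvlm(["What is in this image?"], modality="image")
llm = LLM(**asdict(req_data.engine_args))  # EngineArgs is a dataclass

image = ImageAsset("stop_sign").pil_image
outputs = llm.generate(
    {
        "prompt": req_data.prompts[0],
        "multi_modal_data": {"image": image},
    },
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```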
```diff
@@ -955,6 +983,7 @@ model_example_map = {
     "qwen2_vl": run_qwen2_vl,
     "qwen2_5_vl": run_qwen2_5_vl,
     "skywork_chat": run_skyworkr1v,
+    "smolvlm": run_smolvlm,
 }
 
 
```
```diff
@@ -217,6 +217,33 @@ def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData:
     )
 
 
+def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData:
+    model_name = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
+
+    # The configuration below has been confirmed to launch on a single L40 GPU.
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=8192,
+        max_num_seqs=16,
+        enforce_eager=True,
+        limit_mm_per_prompt={"image": len(image_urls)},
+        mm_processor_kwargs={
+            "max_image_size": {
+                "longest_edge": 384
+            },
+        },
+    )
+
+    placeholders = "\n".join(f"Image-{i}: <image>\n"
+                             for i, _ in enumerate(image_urls, start=1))
+    prompt = f"<|im_start|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:"  # noqa: E501
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+        image_data=[fetch_image(url) for url in image_urls],
+    )
+
+
 def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
     model_name = "OpenGVLab/InternVL2-2B"
 
```
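As above, a hedged sketch of driving `load_smolvlm` end to end; the URLs are placeholders and the driver mirrors the multi-image example script's pattern:

```python
# Hypothetical usage of load_smolvlm() from the hunk above; the image URLs
# are placeholders, not part of the PR.
from dataclasses import asdict

from vllm import LLM, SamplingParams

image_urls = [
    "https://example.com/cat.jpg",  # placeholder
    "https://example.com/dog.jpg",  # placeholder
]
req_data = load_smolvlm("What is shown in each image?", image_urls)
llm = LLM(**asdict(req_data.engine_args))
outputs = llm.generate(
    {
        "prompt": req_data.prompt,
        "multi_modal_data": {"image": req_data.image_data},
    },
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)
```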
```diff
@@ -614,6 +641,7 @@ model_example_map = {
     "qwen_vl_chat": load_qwen_vl_chat,
     "qwen2_vl": load_qwen2_vl,
     "qwen2_5_vl": load_qwen2_5_vl,
+    "smolvlm": load_smolvlm,
 }
 
 
```
```diff
@@ -28,6 +28,7 @@ torchvision==0.21.0
 transformers_stream_generator # required for qwen-vl test
 matplotlib # required for qwen-vl test
 mistral_common[opencv] >= 1.5.4 # required for pixtral test
+num2words # required for smolvlm test
 opencv-python-headless >= 4.11.0 # required for video test
 datamodel_code_generator # required for minicpm3 test
 lm-eval[api]==0.4.8 # required for model evaluation test
```
```diff
@@ -101,6 +101,8 @@ dill==0.3.8
     #   multiprocess
 dnspython==2.7.0
     # via email-validator
+docopt==0.6.2
+    # via num2words
 docutils==0.16
     # via awscli
 einops==0.8.0
```
```diff
@@ -263,6 +265,8 @@ networkx==3.2.1
     # via torch
 nltk==3.9.1
     # via rouge-score
+num2words==0.5.14
+    # via -r requirements/test.in
 numba==0.61.0
     # via
     #   -r requirements/test.in
```
```diff
@@ -493,6 +493,16 @@ VLM_TEST_SETTINGS = {
         patch_hf_runner=model_utils.skyworkr1v_patch_hf_runner,
         marks=[large_gpu_mark(min_gb=80)],
     ),
+    "smolvlm": VLMTestInfo(
+        models=["HuggingFaceTB/SmolVLM2-2.2B-Instruct"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|im_start|>User:{img_prompt}<end_of_utterance>\nAssistant:",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "<image>",
+        max_model_len=8192,
+        max_num_seqs=2,
+        auto_cls=AutoModelForImageTextToText,
+        hf_output_post_proc=model_utils.smolvlm_trunc_hf_output,
+    ),
     ### Tensor parallel / multi-gpu broadcast tests
     "chameleon-broadcast": VLMTestInfo(
         models=["facebook/chameleon-7b"],
```
```diff
@@ -204,6 +204,12 @@ def idefics3_trunc_hf_output(hf_output: RunnerOutput,
     return output_ids, output_str, out_logprobs
 
 
+def smolvlm_trunc_hf_output(hf_output: RunnerOutput,
+                            model: str) -> RunnerOutput:
+    # Based on Idefics3
+    return idefics3_trunc_hf_output(hf_output, model)
+
+
 def minicpmv_trunc_hf_output(hf_output: RunnerOutput,
                              model: str) -> RunnerOutput:
     output_ids, output_str, out_logprobs = hf_output
```
```diff
@@ -257,6 +257,7 @@ def _test_processing_correctness_mistral(
     "h2oai/h2ovl-mississippi-800m",
     "OpenGVLab/InternVL2-1B",
     "HuggingFaceM4/Idefics3-8B-Llama3",
+    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
     "meta-llama/Llama-4-Scout-17B-16E-Instruct",
     "llava-hf/llava-1.5-7b-hf",
     "llava-hf/llava-v1.6-mistral-7b-hf",
```
tests/models/multimodal/processing/test_smolvlm.py (new file, 65 lines):

```python
# SPDX-License-Identifier: Apache-2.0
"""Tests for smolvlm's multimodal preprocessing kwargs."""
import pytest
from transformers import SmolVLMConfig

from vllm.multimodal import MULTIMODAL_REGISTRY

from ....conftest import _ImageAssets
from ...utils import build_model_context


@pytest.mark.parametrize("model_id", ["HuggingFaceTB/SmolVLM2-2.2B-Instruct"])
# yapf: disable
@pytest.mark.parametrize(
    ("mm_processor_kwargs", "expected_toks_per_img"),
    [
        ({"max_image_size": {"longest_edge": 384}}, 1377),
        ({"max_image_size": {"longest_edge": 768}}, 405),
    ])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
    image_assets: _ImageAssets,
    model_id: str,
    mm_processor_kwargs: dict[str, object],
    expected_toks_per_img: int,
    num_imgs: int,
    kwargs_on_init: bool,
):
    """Ensure the SmolVLM multimodal processor handles max_image_size properly."""
    # Same as the previous test - don't initialize mm_processor_kwargs
    # in this test and assume that the kwargs will be correctly expanded by
    # the partial when calling the custom input processor.
    ctx = build_model_context(
        model_id,
        mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None,
        limit_mm_per_prompt={"image": num_imgs},
    )
    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
    hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs

    # Build the image str / prompt based on the number of images we pass
    placeholders = "<image>" if num_imgs == 1 else "\n".join(
        f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1))
    prompt = f"<|im_start|>User:{placeholders}\n<end_of_utterance>\nAssistant:"  # noqa: E501

    # Build mm_data
    image_size = ctx.get_hf_config(SmolVLMConfig).vision_config.image_size
    dummy_image_size = (image_size * 4, image_size * 4)
    dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
    mm_data = {"image": [dummy_image] * num_imgs}

    processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)

    # Ensure the placeholder format is correct
    hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
    hf_processed_inputs = hf_processor(text=prompt, images=mm_data["image"])
    assert processed_inputs["prompt_token_ids"] == hf_processed_inputs[
        "input_ids"][0]

    # Ensure we have the right number of placeholders per max_image_size
    image_token_id = ctx.get_hf_config().image_token_id
    img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
    assert img_tok_count == expected_toks_per_img * num_imgs
```
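The `kwargs_on_init` parametrization covers the two routes by which `max_image_size` can reach the processor: set once on the engine, or passed per call. A minimal sketch of the engine-level route (model ID from the test; the other values are illustrative):

```python
# Engine-level mm_processor_kwargs apply to every subsequent request,
# matching the kwargs_on_init=True branch of the test above.
from vllm import LLM

llm = LLM(
    model="HuggingFaceTB/SmolVLM2-2.2B-Instruct",
    max_model_len=8192,
    mm_processor_kwargs={"max_image_size": {"longest_edge": 384}},
)
```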
```diff
@@ -344,6 +344,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct",  # noqa: E501
                                                           min_transformers_version="4.49"),  # noqa: E501
     "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
+    "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"),  # noqa: E501
     "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b",  # noqa: E501
                                      trust_remote_code=True,
                                      max_transformers_version="4.50"),
```
```diff
@@ -498,7 +498,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
                                hf_config.image_token_index)
         if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2",
                           "internvl_chat", "skywork_chat", "NVLM_D",
-                          "h2ovl_chat", "idefics3"):
+                          "h2ovl_chat", "idefics3", "smolvlm"):
             return "<image>"
         if model_type in ("mllama", "llama4"):
             return "<|image|>"
```
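The `"<image>"` string returned here is what the OpenAI-compatible server splices into chat prompts for `smolvlm`-type models. A hedged sketch of that online path, assuming a server started with `vllm serve HuggingFaceTB/SmolVLM2-2.2B-Instruct` (the address and image URL are placeholders):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="HuggingFaceTB/SmolVLM2-2.2B-Instruct",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            # Placeholder URL; the server maps it to an <image> placeholder.
            {"type": "image_url",
             "image_url": {"url": "https://example.com/demo.jpg"}},
        ],
    }],
)
print(resp.choices[0].message.content)
```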
```diff
@@ -206,6 +206,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
 
         return grid_w * grid_h + 1
 
+    def _get_image_token(
+            self,
+            processor: Optional[Idefics3Processor]) -> tuple[str, str, str]:
+        if processor is None:
+            processor = self.get_hf_processor()
+        image_token = processor.image_token.content
+        fake_image_token = processor.fake_image_token.content
+        global_image_token = processor.global_image_tag
+        return image_token, fake_image_token, global_image_token
+
     def get_image_repl(
         self,
         *,
```
```diff
@@ -216,9 +226,8 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
         if processor is None:
             processor = self.get_hf_processor()
 
-        image_token = processor.image_token.content
-        fake_image_token = processor.fake_image_token.content
-        global_img_token = processor.global_image_tag
+        image_token, fake_image_token, global_img_token = self._get_image_token(
+            processor)
         image_seq_len = processor.image_seq_len
         grid_placeholder = "<row_{n_h}_col_{n_w}>"
 
```
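The point of the new `_get_image_token` hook: Idefics3's HF processor wraps its special tokens in `AddedToken` objects (hence the `.content` accesses), while SmolVLM's processor exposes plain strings, so a subclass overrides one small accessor instead of re-implementing `get_image_repl`. An illustrative sketch with hypothetical class names:

```python
# Illustration only; these classes stand in for the real ProcessingInfo types.
class Idefics3Like:
    def _get_image_token(self, processor):
        # Idefics3: the special tokens are AddedToken objects
        return (processor.image_token.content,
                processor.fake_image_token.content,
                processor.global_image_tag)


class SmolVLMLike(Idefics3Like):
    def _get_image_token(self, processor):
        # SmolVLM: the special tokens are plain strings
        return (processor.image_token,
                processor.fake_image_token,
                processor.global_image_token)
```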
```diff
@@ -300,7 +309,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
         hf_processor = self.info.get_hf_processor()
         image_processor: Idefics3ImageProcessor = hf_processor.image_processor
         longest_edge = image_processor.max_image_size['longest_edge']
-        image_token = hf_processor.image_token.content
+        image_token, _, _ = self.info._get_image_token(hf_processor)
 
         mm_data = {
             "image":
```
```diff
@@ -382,7 +391,7 @@ class Idefics3MultiModalProcessor(
         out_mm_kwargs: MultiModalKwargs,
     ) -> Sequence[PromptUpdate]:
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
-        image_token = hf_processor.image_token.content
+        image_token, _, _ = self.info._get_image_token(hf_processor)
 
         def get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails:
             images = mm_items.get_items("image", ImageProcessorItems)
```
```diff
@@ -175,6 +175,7 @@ _MULTIMODAL_MODELS = {
     "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
     "InternVLChatModel": ("internvl", "InternVLChatModel"),
     "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
+    "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"),  # noqa: E501
     "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
     "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
```
vllm/model_executor/models/smolvlm.py (new file, 51 lines):

```python
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, Optional

from transformers import SmolVLMProcessor

from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY

# yapf: disable
from .idefics3 import Idefics3DummyInputsBuilder as SmolVLMDummyInputsBuilder
from .idefics3 import Idefics3ForConditionalGeneration
from .idefics3 import Idefics3MultiModalProcessor as SmolVLMMultiModalProcessor
from .idefics3 import Idefics3ProcessingInfo

# yapf: enable


class SmolVLMProcessingInfo(Idefics3ProcessingInfo):

    def get_hf_processor(
        self,
        *,
        max_image_size: Optional[Dict[str, int]] = None,
        **kwargs: object,
    ) -> SmolVLMProcessor:
        if max_image_size is not None:
            kwargs["max_image_size"] = max_image_size

        return self.ctx.get_hf_processor(SmolVLMProcessor, **kwargs)

    def _get_image_token(
            self,
            processor: Optional[SmolVLMProcessor]) -> tuple[str, str, str]:
        if processor is None:
            processor = self.get_hf_processor()
        image_token = processor.image_token
        fake_image_token = processor.fake_image_token
        global_image_token = processor.global_image_token
        return image_token, fake_image_token, global_image_token


@MULTIMODAL_REGISTRY.register_processor(SmolVLMMultiModalProcessor,
                                        info=SmolVLMProcessingInfo,
                                        dummy_inputs=SmolVLMDummyInputsBuilder)
class SmolVLMForConditionalGeneration(Idefics3ForConditionalGeneration):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
        )
```
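Because the class only swaps in SmolVLM-specific processing info and otherwise inherits Idefics3 wholesale, registration is effectively the whole integration. A hedged smoke check that the architecture resolves:

```python
# Sanity check (assumes an environment with this commit installed).
from vllm import ModelRegistry

assert "SmolVLMForConditionalGeneration" in ModelRegistry.get_supported_archs()
```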