[Model] Add smolvlm support (#16017)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
Chauncey 2025-04-09 10:12:17 +08:00 committed by GitHub
parent 1f4b09b525
commit 102bf967f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 219 additions and 6 deletions

View File

@ -990,6 +990,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
- * `SmolVLMForConditionalGeneration`
* SmolVLM2
* T + I
* `SmolVLM2-2.2B-Instruct`
*
* ✅︎
* ✅︎
- * `UltravoxModel`
* Ultravox
* T + A<sup>E+</sup>

View File

@ -298,6 +298,34 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
)
# SmolVLM2-2.2B-Instruct
def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
max_num_seqs=2,
enforce_eager=True,
mm_processor_kwargs={
"max_image_size": {
"longest_edge": 384
},
},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
prompts = [
(f"<|im_start|>User:<image>{question}<end_of_utterance>\nAssistant:")
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# InternVL
def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@ -955,6 +983,7 @@ model_example_map = {
"qwen2_vl": run_qwen2_vl,
"qwen2_5_vl": run_qwen2_5_vl,
"skywork_chat": run_skyworkr1v,
"smolvlm": run_smolvlm,
}

View File

@ -217,6 +217,33 @@ def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData:
)
def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
# The configuration below has been confirmed to launch on a single L40 GPU.
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
max_num_seqs=16,
enforce_eager=True,
limit_mm_per_prompt={"image": len(image_urls)},
mm_processor_kwargs={
"max_image_size": {
"longest_edge": 384
},
},
)
placeholders = "\n".join(f"Image-{i}: <image>\n"
for i, _ in enumerate(image_urls, start=1))
prompt = f"<|im_start|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=[fetch_image(url) for url in image_urls],
)
def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "OpenGVLab/InternVL2-2B"
@ -614,6 +641,7 @@ model_example_map = {
"qwen_vl_chat": load_qwen_vl_chat,
"qwen2_vl": load_qwen2_vl,
"qwen2_5_vl": load_qwen2_5_vl,
"smolvlm": load_smolvlm,
}

View File

@ -28,6 +28,7 @@ torchvision==0.21.0
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.5.4 # required for pixtral test
num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test

View File

@ -101,6 +101,8 @@ dill==0.3.8
# multiprocess
dnspython==2.7.0
# via email-validator
docopt==0.6.2
# via num2words
docutils==0.16
# via awscli
einops==0.8.0
@ -263,6 +265,8 @@ networkx==3.2.1
# via torch
nltk==3.9.1
# via rouge-score
num2words==0.5.14
# via -r requirements/test.in
numba==0.61.0
# via
# -r requirements/test.in

View File

@ -493,6 +493,16 @@ VLM_TEST_SETTINGS = {
patch_hf_runner=model_utils.skyworkr1v_patch_hf_runner,
marks=[large_gpu_mark(min_gb=80)],
),
"smolvlm": VLMTestInfo(
models=["HuggingFaceTB/SmolVLM2-2.2B-Instruct"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt:f"<|im_start|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501
img_idx_to_prompt=lambda idx: "<image>",
max_model_len=8192,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
hf_output_post_proc=model_utils.smolvlm_trunc_hf_output,
),
### Tensor parallel / multi-gpu broadcast tests
"chameleon-broadcast": VLMTestInfo(
models=["facebook/chameleon-7b"],

View File

@ -204,6 +204,12 @@ def idefics3_trunc_hf_output(hf_output: RunnerOutput,
return output_ids, output_str, out_logprobs
def smolvlm_trunc_hf_output(hf_output: RunnerOutput,
model: str) -> RunnerOutput:
# Based on Idefics3
return idefics3_trunc_hf_output(hf_output, model)
def minicpmv_trunc_hf_output(hf_output: RunnerOutput,
model: str) -> RunnerOutput:
output_ids, output_str, out_logprobs = hf_output

View File

@ -257,6 +257,7 @@ def _test_processing_correctness_mistral(
"h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B",
"HuggingFaceM4/Idefics3-8B-Llama3",
"HuggingFaceTB/SmolVLM2-2.2B-Instruct",
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
"llava-hf/llava-1.5-7b-hf",
"llava-hf/llava-v1.6-mistral-7b-hf",

View File

@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for smolvlm's multimodal preprocessing kwargs."""
import pytest
from transformers import SmolVLMConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import _ImageAssets
from ...utils import build_model_context
@pytest.mark.parametrize("model_id", ["HuggingFaceTB/SmolVLM2-2.2B-Instruct"])
# yapf: disable
@pytest.mark.parametrize(
("mm_processor_kwargs", "expected_toks_per_img"),
[
({"max_image_size": {"longest_edge": 384}}, 1377),
({"max_image_size": {"longest_edge": 768}}, 405),
])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
image_assets: _ImageAssets,
model_id: str,
mm_processor_kwargs: dict[str, object],
expected_toks_per_img: int,
num_imgs: int,
kwargs_on_init: bool,
):
"""Ensure Idefics3MultiModalProcessor handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the custom input processor.
ctx = build_model_context(
model_id,
mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None,
limit_mm_per_prompt={"image": num_imgs},
)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs
# Build the image str / prompt based on the number of images we pass
placeholders = "<image>" if num_imgs == 1 else "\n".join(
f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1))
prompt = f"<|im_start|>User:{placeholders}\n<end_of_utterance>\nAssistant:" # noqa: E501
# Build mm_data
image_size = ctx.get_hf_config(SmolVLMConfig).vision_config.image_size
dummy_image_size = (image_size * 4, image_size * 4)
dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
mm_data = {"image": [dummy_image] * num_imgs}
processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
# Ensure the placeholders format are correct
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
hf_processed_inputs = hf_processor(text=prompt, images=mm_data["image"])
assert processed_inputs["prompt_token_ids"] == hf_processed_inputs[
"input_ids"][0]
# Ensure we have the right number of placeholders per num_crops size
image_token_id = ctx.get_hf_config().image_token_id
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
assert img_tok_count == expected_toks_per_img * num_imgs

View File

@ -344,6 +344,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
min_transformers_version="4.49"), # noqa: E501
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
trust_remote_code=True,
max_transformers_version="4.50"),

View File

@ -498,7 +498,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
hf_config.image_token_index)
if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2",
"internvl_chat", "skywork_chat", "NVLM_D",
"h2ovl_chat", "idefics3"):
"h2ovl_chat", "idefics3", "smolvlm"):
return "<image>"
if model_type in ("mllama", "llama4"):
return "<|image|>"

View File

@ -206,6 +206,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
return grid_w * grid_h + 1
def _get_image_token(
self,
processor: Optional[Idefics3Processor]) -> tuple[str, str, str]:
if processor is None:
processor = self.get_hf_processor()
image_token = processor.image_token.content
fake_image_token = processor.fake_image_token.content
global_image_token = processor.global_image_tag
return image_token, fake_image_token, global_image_token
def get_image_repl(
self,
*,
@ -216,9 +226,8 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
if processor is None:
processor = self.get_hf_processor()
image_token = processor.image_token.content
fake_image_token = processor.fake_image_token.content
global_img_token = processor.global_image_tag
image_token, fake_image_token, global_img_token = self._get_image_token(
processor)
image_seq_len = processor.image_seq_len
grid_placeholder = "<row_{n_h}_col_{n_w}>"
@ -300,7 +309,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
hf_processor = self.info.get_hf_processor()
image_processor: Idefics3ImageProcessor = hf_processor.image_processor
longest_edge = image_processor.max_image_size['longest_edge']
image_token = hf_processor.image_token.content
image_token, _, _ = self.info._get_image_token(hf_processor)
mm_data = {
"image":
@ -382,7 +391,7 @@ class Idefics3MultiModalProcessor(
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token.content
image_token, _, _ = self.info._get_image_token(hf_processor)
def get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails:
images = mm_items.get_items("image", ImageProcessorItems)

View File

@ -175,6 +175,7 @@ _MULTIMODAL_MODELS = {
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
"SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501
"LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501

View File

@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, Optional
from transformers import SmolVLMProcessor
from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
# yapf: disable
from .idefics3 import Idefics3DummyInputsBuilder as SmolVLMDummyInputsBuilder
from .idefics3 import Idefics3ForConditionalGeneration
from .idefics3 import Idefics3MultiModalProcessor as SmolVLMMultiModalProcessor
from .idefics3 import Idefics3ProcessingInfo
# yapf: enable
class SmolVLMProcessingInfo(Idefics3ProcessingInfo):
def get_hf_processor(
self,
*,
max_image_size: Optional[Dict[str, int]] = None,
**kwargs: object,
) -> SmolVLMProcessor:
if max_image_size is not None:
kwargs["max_image_size"] = max_image_size
return self.ctx.get_hf_processor(SmolVLMProcessor, **kwargs)
def _get_image_token(
self, processor: Optional[SmolVLMProcessor]) -> tuple[str, str]:
if processor is None:
processor = self.get_hf_processor()
image_token = processor.image_token
fake_image_token = processor.fake_image_token
global_image_token = processor.global_image_token
return image_token, fake_image_token, global_image_token
@MULTIMODAL_REGISTRY.register_processor(SmolVLMMultiModalProcessor,
info=SmolVLMProcessingInfo,
dummy_inputs=SmolVLMDummyInputsBuilder)
class SmolVLMForConditionalGeneration(Idefics3ForConditionalGeneration):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(
vllm_config=vllm_config,
prefix=prefix,
)