[Model] Support Skywork-R1V (#15397)

Signed-off-by: jiacai.liu <932997367@qq.com>
Co-authored-by: jiacai.liu <932997367@qq.com>
pengyuange 2025-03-29 11:39:21 +08:00 committed by GitHub
parent c802f5430d
commit de1cb38769
12 changed files with 1194 additions and 7 deletions
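For orientation, here is a minimal offline-inference sketch for the newly supported model. The model name, trust_remote_code, max_model_len and chat-template handling mirror the example added in this commit; the image path and sampling settings are placeholders.

from PIL import Image
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

model_name = "Skywork/Skywork-R1V-38B"

# Build the chat-formatted prompt the same way the example script below does.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
messages = [{"role": "user", "content": "<image>\nWhat is shown in this image?"}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

llm = LLM(model=model_name, trust_remote_code=True, max_model_len=4096)

image = Image.open("example.jpg")  # placeholder image path
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=256),
)
print(outputs[0].outputs[0].text)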


@@ -921,6 +921,13 @@ See [this page](#generative-models) for more information on how to use generative models.
* ✅︎
* ✅︎
* ✅︎
- * `SkyworkR1VChatModel`
* Skywork-R1V-38B
* T + I
* `Skywork/Skywork-R1V-38B`
*
* ✅︎
* ✅︎
- * `UltravoxModel`
* Ultravox
* T + A<sup>E+</sup>


@@ -804,6 +804,41 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
)
# SkyworkR1V
def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "Skywork/Skywork-R1V-38B"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
messages = [[{
'role': 'user',
'content': f"<image>\n{question}"
}] for question in questions]
prompts = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
# Stop tokens for SkyworkR1V
# https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/conversation.py
stop_tokens = ["<｜end▁of▁sentence｜>", "<|endoftext|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
stop_token_ids=stop_token_ids,
)
model_example_map = {
"aria": run_aria,
"blip-2": run_blip2,
@@ -834,6 +869,7 @@ model_example_map = {
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
"qwen2_5_vl": run_qwen2_5_vl,
"skywork_chat": run_skyworkr1v,
}
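If the example script keeps its existing CLI, the new entry can be exercised with something like `python examples/offline_inference/vision_language.py --model-type skywork_chat`; the script path and the `--model-type` flag are assumed from the surrounding example file rather than shown in this hunk.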


@@ -474,6 +474,20 @@ VLM_TEST_SETTINGS = {
vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output,
prompt_path_encoder=model_utils.qwen_prompt_path_encoder,
),
"skywork_r1v": VLMTestInfo(
models=["Skywork/Skywork-R1V-38B"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<｜begin▁of▁sentence｜><｜User｜>\n{img_prompt}<｜Assistant｜><think>\n", # noqa: E501
single_image_prompts=IMAGE_ASSETS.prompts({
"stop_sign": "<image>\nWhat's the content in the center of the image?", # noqa: E501
"cherry_blossom": "<image>\nWhat is the season?",
}),
multi_image_prompt="<image>\n<image>\nDescribe the two images in short.", # noqa: E501
max_model_len=4096,
use_tokenizer_eos=True,
patch_hf_runner=model_utils.skyworkr1v_patch_hf_runner,
marks=[large_gpu_mark(min_gb=80)],
),
### Tensor parallel / multi-gpu broadcast tests
"chameleon-broadcast": VLMTestInfo(
models=["facebook/chameleon-7b"],


@@ -376,6 +376,63 @@ def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
return hf_model
def skyworkr1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for SkyworkR1V."""
class SkyworkR1VProcessor:
"""A simple processor for SkyworkR1V."""
def __init__(self, hf_runner: HfRunner):
self.num_image_token = hf_runner.model.num_image_token
self.tokenizer = hf_runner.tokenizer
self.config = AutoConfig.from_pretrained(hf_runner.model_name,
trust_remote_code=True)
self.vision_config = self.config.vision_config
self.use_thumbnail = self.config.use_thumbnail
self.min_num = self.config.min_dynamic_patch
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size
def __call__(self, text: str, images: Union[Image, list[Image]],
**kwargs):
from vllm.model_executor.models.skyworkr1v import (
IMG_CONTEXT, IMG_END, IMG_START,
image_to_pixel_values_skyworkr1v)
images = [images] if isinstance(images, Image) else images
pixel_values = [
image_to_pixel_values_skyworkr1v(
image,
input_size=self.image_size,
min_num=self.min_num,
max_num=self.max_num,
use_thumbnail=self.use_thumbnail,
) for image in images
]
num_patches_list = [
pixel_value.shape[0] for pixel_value in pixel_values
]
pixel_values = torch.cat(pixel_values, dim=0)
for num_patches in num_patches_list:
context_tokens = IMG_CONTEXT * self.num_image_token \
* num_patches
image_tokens = IMG_START + context_tokens + IMG_END
text = text.replace('<image>', image_tokens, 1)
prompt = self.tokenizer(text, return_tensors="pt")
prompt.update({"pixel_values": pixel_values})
return prompt
img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids(
"<IMG_CONTEXT>")
hf_model.model.img_context_token_id = img_context_token_id
hf_model.processor = SkyworkR1VProcessor(hf_model)
hf_model.model.get_output_embeddings = lambda: \
hf_model.model.language_model.get_output_embeddings()
hf_model.model.generate = types.MethodType(_internvl_generate,
hf_model.model)
return hf_model
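# Rough worked example (not part of this diff) of the placeholder expansion done
# in SkyworkR1VProcessor.__call__ above. The literal token strings and the patch
# count below are assumptions based on the InternVL-style layout.
IMG_START, IMG_END, IMG_CONTEXT = "<img>", "</img>", "<IMG_CONTEXT>"  # assumed literals
num_image_token, num_patches = 256, 7  # e.g. 6 dynamic tiles + 1 thumbnail (assumed)
image_tokens = IMG_START + IMG_CONTEXT * num_image_token * num_patches + IMG_END
text = "<image>\nWhat's in the picture?".replace("<image>", image_tokens, 1)
# "text" now carries 256 * 7 = 1792 <IMG_CONTEXT> tokens between <img> and </img>,
# matching a pixel_values tensor of shape (7, 3, image_size, image_size).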
def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for InternVL."""


@@ -262,22 +262,23 @@ def _test_processing_correctness_mistral(
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"TIGER-Lab/Mantis-8B-siglip-llama3",
"mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b",
"openbmb/MiniCPM-Llama3-V-2_5",
"openbmb/MiniCPM-o-2_6",
"openbmb/MiniCPM-V-2_6",
"allenai/Molmo-7B-D-0924",
"allenai/Molmo-7B-O-0924",
"nvidia/NVLM-D-72B",
"google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448",
"mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b",
"Qwen/Qwen-VL-Chat",
"Qwen/Qwen2-VL-2B-Instruct",
"Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
"Skywork/Skywork-R1V-38B",
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3",
"google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])


@@ -294,6 +294,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
min_transformers_version="4.49"), # noqa: E501
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
trust_remote_code=True),
# [Encoder-decoder]


@@ -496,7 +496,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat",
"NVLM_D", "h2ovl_chat"):
"skywork_chat", "NVLM_D", "h2ovl_chat"):
return "<image>"
if model_type == "mllama":
return "<|image|>"


@@ -190,6 +190,7 @@ _MULTIMODAL_MODELS = {
# [Encoder-decoder]
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
"SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
}
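A quick sanity check that the new architecture is visible through vLLM's public registry; a sketch using the existing ModelRegistry lookup.

from vllm import ModelRegistry

# The architecture name mapped above should now be listed.
assert "SkyworkR1VChatModel" in ModelRegistry.get_supported_archs()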

File diff suppressed because it is too large.


@@ -37,8 +37,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
MLPSpeculatorConfig, MPTConfig,
NemotronConfig, NVLM_D_Config,
Olmo2Config, RWConfig,
SolarConfig, Telechat2Config,
UltravoxConfig)
SkyworkR1VChatConfig, SolarConfig,
Telechat2Config, UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname
@@ -76,6 +76,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"NVLM_D": NVLM_D_Config,
"olmo2": Olmo2Config,
"solar": SolarConfig,
"skywork_chat": SkyworkR1VChatConfig,
"telechat": Telechat2Config,
"ultravox": UltravoxConfig,
**_CONFIG_REGISTRY_OVERRIDE_HF


@@ -20,6 +20,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
from vllm.transformers_utils.configs.olmo2 import Olmo2Config
from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig
from vllm.transformers_utils.configs.solar import SolarConfig
from vllm.transformers_utils.configs.telechat2 import Telechat2Config
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
@@ -42,6 +43,7 @@ __all__ = [
"NemotronConfig",
"NVLM_D_Config",
"Olmo2Config",
"SkyworkR1VChatConfig",
"SolarConfig",
"Telechat2Config",
"UltravoxConfig",


@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/configuration_skywork_chat.py
# --------------------------------------------------------
# SkyworkR1V
# Copyright (c) 2025 Skywork
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from transformers.configuration_utils import PretrainedConfig
class SkyworkR1VChatConfig(PretrainedConfig):
model_type = 'internvl_chat'
is_composition = True
def __init__(self,
vision_config=None,
llm_config=None,
use_backbone_lora=0,
use_llm_lora=0,
select_layer=-1,
force_image_size=None,
downsample_ratio=0.5,
template=None,
dynamic_image_size=False,
use_thumbnail=False,
ps_version='v1',
min_dynamic_patch=1,
max_dynamic_patch=6,
**kwargs):
super().__init__(**kwargs)
if vision_config is None:
vision_config = {}
if llm_config is None:
llm_config = {}
self.vision_config = PretrainedConfig(**vision_config)
self.text_config = PretrainedConfig(**llm_config)
self.use_backbone_lora = use_backbone_lora
self.use_llm_lora = use_llm_lora
self.select_layer = select_layer
self.force_image_size = force_image_size
self.downsample_ratio = downsample_ratio
self.template = template
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail = use_thumbnail
self.ps_version = ps_version # pixel shuffle version
self.min_dynamic_patch = min_dynamic_patch
self.max_dynamic_patch = max_dynamic_patch
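Finally, a small sketch exercising the new config class directly; the override values are invented for illustration.

from vllm.transformers_utils.configs import SkyworkR1VChatConfig

# Nested dicts are wrapped into plain PretrainedConfig objects, as defined above.
cfg = SkyworkR1VChatConfig(
    vision_config={"image_size": 448},
    dynamic_image_size=True,
    max_dynamic_patch=6,
)
print(cfg.vision_config.image_size)       # 448
print(cfg.use_thumbnail, cfg.ps_version)  # False v1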