[Model] Support Skywork-R1V (#15397)
Signed-off-by: jiacai.liu <932997367@qq.com> Co-authored-by: jiacai.liu <932997367@qq.com>
This commit is contained in:
parent
c802f5430d
commit
de1cb38769
@ -921,6 +921,13 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
* ✅︎
|
||||
* ✅︎
|
||||
* ✅︎
|
||||
- * `SkyworkR1VChatModel`
|
||||
* Skywork-R1V-38B
|
||||
* T + I
|
||||
* `Skywork/Skywork-R1V-38B`
|
||||
*
|
||||
* ✅︎
|
||||
* ✅︎
|
||||
- * `UltravoxModel`
|
||||
* Ultravox
|
||||
* T + A<sup>E+</sup>
|
||||
|
@ -804,6 +804,41 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
# SkyworkR1V
|
||||
def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
model_name = "Skywork/Skywork-R1V-38B"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name,
|
||||
trust_remote_code=True)
|
||||
messages = [[{
|
||||
'role': 'user',
|
||||
'content': f"<image>\n{question}"
|
||||
}] for question in questions]
|
||||
prompts = tokenizer.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
|
||||
# Stop tokens for SkyworkR1V
|
||||
# https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/conversation.py
|
||||
stop_tokens = ["<|end▁of▁sentence|>", "<|endoftext|>"]
|
||||
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
stop_token_ids=stop_token_ids,
|
||||
)
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"aria": run_aria,
|
||||
"blip-2": run_blip2,
|
||||
@ -834,6 +869,7 @@ model_example_map = {
|
||||
"qwen_vl": run_qwen_vl,
|
||||
"qwen2_vl": run_qwen2_vl,
|
||||
"qwen2_5_vl": run_qwen2_5_vl,
|
||||
"skywork_chat": run_skyworkr1v,
|
||||
}
|
||||
|
||||
|
||||
|
@ -474,6 +474,20 @@ VLM_TEST_SETTINGS = {
|
||||
vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output,
|
||||
prompt_path_encoder=model_utils.qwen_prompt_path_encoder,
|
||||
),
|
||||
"skywork_r1v": VLMTestInfo(
|
||||
models=["Skywork/Skywork-R1V-38B"],
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||
prompt_formatter=lambda img_prompt: f"<|begin▁of▁sentence|><|User|>\n{img_prompt}<|Assistant|><think>\n", # noqa: E501
|
||||
single_image_prompts=IMAGE_ASSETS.prompts({
|
||||
"stop_sign": "<image>\nWhat's the content in the center of the image?", # noqa: E501
|
||||
"cherry_blossom": "<image>\nWhat is the season?",
|
||||
}),
|
||||
multi_image_prompt="<image>\n<image>\nDescribe the two images in short.", # noqa: E501
|
||||
max_model_len=4096,
|
||||
use_tokenizer_eos=True,
|
||||
patch_hf_runner=model_utils.skyworkr1v_patch_hf_runner,
|
||||
marks=[large_gpu_mark(min_gb=80)],
|
||||
),
|
||||
### Tensor parallel / multi-gpu broadcast tests
|
||||
"chameleon-broadcast": VLMTestInfo(
|
||||
models=["facebook/chameleon-7b"],
|
||||
|
@ -376,6 +376,63 @@ def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
return hf_model
|
||||
|
||||
|
||||
def skyworkr1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
"""Patches and returns an instance of the HfRunner to use for SkyworkR1V."""
|
||||
|
||||
class SkyworkR1VProcessor:
|
||||
"""A simple processor for SkyworkR1V."""
|
||||
|
||||
def __init__(self, hf_runner: HfRunner):
|
||||
self.num_image_token = hf_runner.model.num_image_token
|
||||
self.tokenizer = hf_runner.tokenizer
|
||||
|
||||
self.config = AutoConfig.from_pretrained(hf_runner.model_name,
|
||||
trust_remote_code=True)
|
||||
self.vision_config = self.config.vision_config
|
||||
self.use_thumbnail = self.config.use_thumbnail
|
||||
self.min_num = self.config.min_dynamic_patch
|
||||
self.max_num = self.config.max_dynamic_patch
|
||||
self.image_size = self.vision_config.image_size
|
||||
|
||||
def __call__(self, text: str, images: Union[Image, list[Image]],
|
||||
**kwargs):
|
||||
from vllm.model_executor.models.skyworkr1v import (
|
||||
IMG_CONTEXT, IMG_END, IMG_START,
|
||||
image_to_pixel_values_skyworkr1v)
|
||||
images = [images] if isinstance(images, Image) else images
|
||||
pixel_values = [
|
||||
image_to_pixel_values_skyworkr1v(
|
||||
image,
|
||||
input_size=self.image_size,
|
||||
min_num=self.min_num,
|
||||
max_num=self.max_num,
|
||||
use_thumbnail=self.use_thumbnail,
|
||||
) for image in images
|
||||
]
|
||||
num_patches_list = [
|
||||
pixel_value.shape[0] for pixel_value in pixel_values
|
||||
]
|
||||
pixel_values = torch.cat(pixel_values, dim=0)
|
||||
for num_patches in num_patches_list:
|
||||
context_tokens = IMG_CONTEXT * self.num_image_token \
|
||||
* num_patches
|
||||
image_tokens = IMG_START + context_tokens + IMG_END
|
||||
text = text.replace('<image>', image_tokens, 1)
|
||||
prompt = self.tokenizer(text, return_tensors="pt")
|
||||
prompt.update({"pixel_values": pixel_values})
|
||||
return prompt
|
||||
|
||||
img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids(
|
||||
"<IMG_CONTEXT>")
|
||||
hf_model.model.img_context_token_id = img_context_token_id
|
||||
hf_model.processor = SkyworkR1VProcessor(hf_model)
|
||||
hf_model.model.get_output_embeddings = lambda: \
|
||||
hf_model.model.language_model.get_output_embeddings()
|
||||
hf_model.model.generate = types.MethodType(_internvl_generate,
|
||||
hf_model.model)
|
||||
return hf_model
|
||||
|
||||
|
||||
def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
"""Patches and returns an instance of the HfRunner to use for InternVL."""
|
||||
|
||||
|
@ -262,22 +262,23 @@ def _test_processing_correctness_mistral(
|
||||
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
|
||||
"meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
"TIGER-Lab/Mantis-8B-siglip-llama3",
|
||||
"mistralai/Pixtral-12B-2409",
|
||||
"mistral-community/pixtral-12b",
|
||||
"openbmb/MiniCPM-Llama3-V-2_5",
|
||||
"openbmb/MiniCPM-o-2_6",
|
||||
"openbmb/MiniCPM-V-2_6",
|
||||
"allenai/Molmo-7B-D-0924",
|
||||
"allenai/Molmo-7B-O-0924",
|
||||
"nvidia/NVLM-D-72B",
|
||||
"google/paligemma-3b-mix-224",
|
||||
"google/paligemma2-3b-ft-docci-448",
|
||||
"mistralai/Pixtral-12B-2409",
|
||||
"mistral-community/pixtral-12b",
|
||||
"Qwen/Qwen-VL-Chat",
|
||||
"Qwen/Qwen2-VL-2B-Instruct",
|
||||
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
"Qwen/Qwen2-Audio-7B-Instruct",
|
||||
"Skywork/Skywork-R1V-38B",
|
||||
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
|
||||
"openai/whisper-large-v3",
|
||||
"google/paligemma-3b-mix-224",
|
||||
"google/paligemma2-3b-ft-docci-448",
|
||||
])
|
||||
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
||||
@pytest.mark.parametrize("num_batches", [32])
|
||||
|
@ -294,6 +294,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
|
||||
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
|
||||
min_transformers_version="4.49"), # noqa: E501
|
||||
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
|
||||
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
# [Encoder-decoder]
|
||||
|
@ -496,7 +496,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
return self._cached_token_str(self._tokenizer,
|
||||
hf_config.image_token_index)
|
||||
if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat",
|
||||
"NVLM_D", "h2ovl_chat"):
|
||||
"skywork_chat", "NVLM_D", "h2ovl_chat"):
|
||||
return "<image>"
|
||||
if model_type == "mllama":
|
||||
return "<|image|>"
|
||||
|
@ -190,6 +190,7 @@ _MULTIMODAL_MODELS = {
|
||||
# [Encoder-decoder]
|
||||
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
|
||||
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
|
||||
"SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
|
||||
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
|
||||
}
|
||||
|
||||
|
1014
vllm/model_executor/models/skyworkr1v.py
Normal file
1014
vllm/model_executor/models/skyworkr1v.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -37,8 +37,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
|
||||
MLPSpeculatorConfig, MPTConfig,
|
||||
NemotronConfig, NVLM_D_Config,
|
||||
Olmo2Config, RWConfig,
|
||||
SolarConfig, Telechat2Config,
|
||||
UltravoxConfig)
|
||||
SkyworkR1VChatConfig, SolarConfig,
|
||||
Telechat2Config, UltravoxConfig)
|
||||
# yapf: enable
|
||||
from vllm.transformers_utils.utils import check_gguf_file
|
||||
from vllm.utils import resolve_obj_by_qualname
|
||||
@ -76,6 +76,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
||||
"NVLM_D": NVLM_D_Config,
|
||||
"olmo2": Olmo2Config,
|
||||
"solar": SolarConfig,
|
||||
"skywork_chat": SkyworkR1VChatConfig,
|
||||
"telechat": Telechat2Config,
|
||||
"ultravox": UltravoxConfig,
|
||||
**_CONFIG_REGISTRY_OVERRIDE_HF
|
||||
|
@ -20,6 +20,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig
|
||||
from vllm.transformers_utils.configs.nemotron import NemotronConfig
|
||||
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
|
||||
from vllm.transformers_utils.configs.olmo2 import Olmo2Config
|
||||
from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig
|
||||
from vllm.transformers_utils.configs.solar import SolarConfig
|
||||
from vllm.transformers_utils.configs.telechat2 import Telechat2Config
|
||||
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
||||
@ -42,6 +43,7 @@ __all__ = [
|
||||
"NemotronConfig",
|
||||
"NVLM_D_Config",
|
||||
"Olmo2Config",
|
||||
"SkyworkR1VChatConfig",
|
||||
"SolarConfig",
|
||||
"Telechat2Config",
|
||||
"UltravoxConfig",
|
||||
|
53
vllm/transformers_utils/configs/skyworkr1v.py
Normal file
53
vllm/transformers_utils/configs/skyworkr1v.py
Normal file
@ -0,0 +1,53 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from
|
||||
# https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/configuration_skywork_chat.py
|
||||
# --------------------------------------------------------
|
||||
# SkyworkR1V
|
||||
# Copyright (c) 2025 Skywork
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class SkyworkR1VChatConfig(PretrainedConfig):
|
||||
model_type = 'internvl_chat'
|
||||
is_composition = True
|
||||
|
||||
def __init__(self,
|
||||
vision_config=None,
|
||||
llm_config=None,
|
||||
use_backbone_lora=0,
|
||||
use_llm_lora=0,
|
||||
select_layer=-1,
|
||||
force_image_size=None,
|
||||
downsample_ratio=0.5,
|
||||
template=None,
|
||||
dynamic_image_size=False,
|
||||
use_thumbnail=False,
|
||||
ps_version='v1',
|
||||
min_dynamic_patch=1,
|
||||
max_dynamic_patch=6,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if vision_config is None:
|
||||
vision_config = {}
|
||||
|
||||
if llm_config is None:
|
||||
llm_config = {}
|
||||
|
||||
self.vision_config = PretrainedConfig(**vision_config)
|
||||
self.text_config = PretrainedConfig(**llm_config)
|
||||
|
||||
self.use_backbone_lora = use_backbone_lora
|
||||
self.use_llm_lora = use_llm_lora
|
||||
self.select_layer = select_layer
|
||||
self.force_image_size = force_image_size
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.template = template
|
||||
self.dynamic_image_size = dynamic_image_size
|
||||
self.use_thumbnail = use_thumbnail
|
||||
self.ps_version = ps_version # pixel shuffle version
|
||||
self.min_dynamic_patch = min_dynamic_patch
|
||||
self.max_dynamic_patch = max_dynamic_patch
|
Loading…
x
Reference in New Issue
Block a user