[Model] Support Skywork-R1V (#15397)
Signed-off-by: jiacai.liu <932997367@qq.com> Co-authored-by: jiacai.liu <932997367@qq.com>
This commit is contained in:
parent
c802f5430d
commit
de1cb38769
@ -921,6 +921,13 @@ See [this page](#generative-models) for more information on how to use generativ
|
|||||||
* ✅︎
|
* ✅︎
|
||||||
* ✅︎
|
* ✅︎
|
||||||
* ✅︎
|
* ✅︎
|
||||||
|
- * `SkyworkR1VChatModel`
|
||||||
|
* Skywork-R1V-38B
|
||||||
|
* T + I
|
||||||
|
* `Skywork/Skywork-R1V-38B`
|
||||||
|
*
|
||||||
|
* ✅︎
|
||||||
|
* ✅︎
|
||||||
- * `UltravoxModel`
|
- * `UltravoxModel`
|
||||||
* Ultravox
|
* Ultravox
|
||||||
* T + A<sup>E+</sup>
|
* T + A<sup>E+</sup>
|
||||||
|
@ -804,6 +804,41 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# SkyworkR1V
|
||||||
|
def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
|
||||||
|
assert modality == "image"
|
||||||
|
|
||||||
|
model_name = "Skywork/Skywork-R1V-38B"
|
||||||
|
|
||||||
|
engine_args = EngineArgs(
|
||||||
|
model=model_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
max_model_len=4096,
|
||||||
|
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name,
|
||||||
|
trust_remote_code=True)
|
||||||
|
messages = [[{
|
||||||
|
'role': 'user',
|
||||||
|
'content': f"<image>\n{question}"
|
||||||
|
}] for question in questions]
|
||||||
|
prompts = tokenizer.apply_chat_template(messages,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True)
|
||||||
|
|
||||||
|
# Stop tokens for SkyworkR1V
|
||||||
|
# https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/conversation.py
|
||||||
|
stop_tokens = ["<|end▁of▁sentence|>", "<|endoftext|>"]
|
||||||
|
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
|
||||||
|
|
||||||
|
return ModelRequestData(
|
||||||
|
engine_args=engine_args,
|
||||||
|
prompts=prompts,
|
||||||
|
stop_token_ids=stop_token_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
model_example_map = {
|
model_example_map = {
|
||||||
"aria": run_aria,
|
"aria": run_aria,
|
||||||
"blip-2": run_blip2,
|
"blip-2": run_blip2,
|
||||||
@ -834,6 +869,7 @@ model_example_map = {
|
|||||||
"qwen_vl": run_qwen_vl,
|
"qwen_vl": run_qwen_vl,
|
||||||
"qwen2_vl": run_qwen2_vl,
|
"qwen2_vl": run_qwen2_vl,
|
||||||
"qwen2_5_vl": run_qwen2_5_vl,
|
"qwen2_5_vl": run_qwen2_5_vl,
|
||||||
|
"skywork_chat": run_skyworkr1v,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -474,6 +474,20 @@ VLM_TEST_SETTINGS = {
|
|||||||
vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output,
|
vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output,
|
||||||
prompt_path_encoder=model_utils.qwen_prompt_path_encoder,
|
prompt_path_encoder=model_utils.qwen_prompt_path_encoder,
|
||||||
),
|
),
|
||||||
|
"skywork_r1v": VLMTestInfo(
|
||||||
|
models=["Skywork/Skywork-R1V-38B"],
|
||||||
|
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||||
|
prompt_formatter=lambda img_prompt: f"<|begin▁of▁sentence|><|User|>\n{img_prompt}<|Assistant|><think>\n", # noqa: E501
|
||||||
|
single_image_prompts=IMAGE_ASSETS.prompts({
|
||||||
|
"stop_sign": "<image>\nWhat's the content in the center of the image?", # noqa: E501
|
||||||
|
"cherry_blossom": "<image>\nWhat is the season?",
|
||||||
|
}),
|
||||||
|
multi_image_prompt="<image>\n<image>\nDescribe the two images in short.", # noqa: E501
|
||||||
|
max_model_len=4096,
|
||||||
|
use_tokenizer_eos=True,
|
||||||
|
patch_hf_runner=model_utils.skyworkr1v_patch_hf_runner,
|
||||||
|
marks=[large_gpu_mark(min_gb=80)],
|
||||||
|
),
|
||||||
### Tensor parallel / multi-gpu broadcast tests
|
### Tensor parallel / multi-gpu broadcast tests
|
||||||
"chameleon-broadcast": VLMTestInfo(
|
"chameleon-broadcast": VLMTestInfo(
|
||||||
models=["facebook/chameleon-7b"],
|
models=["facebook/chameleon-7b"],
|
||||||
|
@ -376,6 +376,63 @@ def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
|||||||
return hf_model
|
return hf_model
|
||||||
|
|
||||||
|
|
||||||
|
def skyworkr1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||||
|
"""Patches and returns an instance of the HfRunner to use for SkyworkR1V."""
|
||||||
|
|
||||||
|
class SkyworkR1VProcessor:
|
||||||
|
"""A simple processor for SkyworkR1V."""
|
||||||
|
|
||||||
|
def __init__(self, hf_runner: HfRunner):
|
||||||
|
self.num_image_token = hf_runner.model.num_image_token
|
||||||
|
self.tokenizer = hf_runner.tokenizer
|
||||||
|
|
||||||
|
self.config = AutoConfig.from_pretrained(hf_runner.model_name,
|
||||||
|
trust_remote_code=True)
|
||||||
|
self.vision_config = self.config.vision_config
|
||||||
|
self.use_thumbnail = self.config.use_thumbnail
|
||||||
|
self.min_num = self.config.min_dynamic_patch
|
||||||
|
self.max_num = self.config.max_dynamic_patch
|
||||||
|
self.image_size = self.vision_config.image_size
|
||||||
|
|
||||||
|
def __call__(self, text: str, images: Union[Image, list[Image]],
|
||||||
|
**kwargs):
|
||||||
|
from vllm.model_executor.models.skyworkr1v import (
|
||||||
|
IMG_CONTEXT, IMG_END, IMG_START,
|
||||||
|
image_to_pixel_values_skyworkr1v)
|
||||||
|
images = [images] if isinstance(images, Image) else images
|
||||||
|
pixel_values = [
|
||||||
|
image_to_pixel_values_skyworkr1v(
|
||||||
|
image,
|
||||||
|
input_size=self.image_size,
|
||||||
|
min_num=self.min_num,
|
||||||
|
max_num=self.max_num,
|
||||||
|
use_thumbnail=self.use_thumbnail,
|
||||||
|
) for image in images
|
||||||
|
]
|
||||||
|
num_patches_list = [
|
||||||
|
pixel_value.shape[0] for pixel_value in pixel_values
|
||||||
|
]
|
||||||
|
pixel_values = torch.cat(pixel_values, dim=0)
|
||||||
|
for num_patches in num_patches_list:
|
||||||
|
context_tokens = IMG_CONTEXT * self.num_image_token \
|
||||||
|
* num_patches
|
||||||
|
image_tokens = IMG_START + context_tokens + IMG_END
|
||||||
|
text = text.replace('<image>', image_tokens, 1)
|
||||||
|
prompt = self.tokenizer(text, return_tensors="pt")
|
||||||
|
prompt.update({"pixel_values": pixel_values})
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids(
|
||||||
|
"<IMG_CONTEXT>")
|
||||||
|
hf_model.model.img_context_token_id = img_context_token_id
|
||||||
|
hf_model.processor = SkyworkR1VProcessor(hf_model)
|
||||||
|
hf_model.model.get_output_embeddings = lambda: \
|
||||||
|
hf_model.model.language_model.get_output_embeddings()
|
||||||
|
hf_model.model.generate = types.MethodType(_internvl_generate,
|
||||||
|
hf_model.model)
|
||||||
|
return hf_model
|
||||||
|
|
||||||
|
|
||||||
def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||||
"""Patches and returns an instance of the HfRunner to use for InternVL."""
|
"""Patches and returns an instance of the HfRunner to use for InternVL."""
|
||||||
|
|
||||||
|
@ -262,22 +262,23 @@ def _test_processing_correctness_mistral(
|
|||||||
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
|
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
|
||||||
"meta-llama/Llama-3.2-11B-Vision-Instruct",
|
"meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||||
"TIGER-Lab/Mantis-8B-siglip-llama3",
|
"TIGER-Lab/Mantis-8B-siglip-llama3",
|
||||||
"mistralai/Pixtral-12B-2409",
|
|
||||||
"mistral-community/pixtral-12b",
|
|
||||||
"openbmb/MiniCPM-Llama3-V-2_5",
|
"openbmb/MiniCPM-Llama3-V-2_5",
|
||||||
"openbmb/MiniCPM-o-2_6",
|
"openbmb/MiniCPM-o-2_6",
|
||||||
"openbmb/MiniCPM-V-2_6",
|
"openbmb/MiniCPM-V-2_6",
|
||||||
"allenai/Molmo-7B-D-0924",
|
"allenai/Molmo-7B-D-0924",
|
||||||
"allenai/Molmo-7B-O-0924",
|
"allenai/Molmo-7B-O-0924",
|
||||||
"nvidia/NVLM-D-72B",
|
"nvidia/NVLM-D-72B",
|
||||||
|
"google/paligemma-3b-mix-224",
|
||||||
|
"google/paligemma2-3b-ft-docci-448",
|
||||||
|
"mistralai/Pixtral-12B-2409",
|
||||||
|
"mistral-community/pixtral-12b",
|
||||||
"Qwen/Qwen-VL-Chat",
|
"Qwen/Qwen-VL-Chat",
|
||||||
"Qwen/Qwen2-VL-2B-Instruct",
|
"Qwen/Qwen2-VL-2B-Instruct",
|
||||||
"Qwen/Qwen2.5-VL-3B-Instruct",
|
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||||
"Qwen/Qwen2-Audio-7B-Instruct",
|
"Qwen/Qwen2-Audio-7B-Instruct",
|
||||||
|
"Skywork/Skywork-R1V-38B",
|
||||||
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
|
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
|
||||||
"openai/whisper-large-v3",
|
"openai/whisper-large-v3",
|
||||||
"google/paligemma-3b-mix-224",
|
|
||||||
"google/paligemma2-3b-ft-docci-448",
|
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
||||||
@pytest.mark.parametrize("num_batches", [32])
|
@pytest.mark.parametrize("num_batches", [32])
|
||||||
|
@ -294,6 +294,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
|||||||
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
|
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
|
||||||
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
|
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
|
||||||
min_transformers_version="4.49"), # noqa: E501
|
min_transformers_version="4.49"), # noqa: E501
|
||||||
|
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
|
||||||
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
|
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
|
||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
# [Encoder-decoder]
|
# [Encoder-decoder]
|
||||||
|
@ -496,7 +496,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
|||||||
return self._cached_token_str(self._tokenizer,
|
return self._cached_token_str(self._tokenizer,
|
||||||
hf_config.image_token_index)
|
hf_config.image_token_index)
|
||||||
if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat",
|
if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat",
|
||||||
"NVLM_D", "h2ovl_chat"):
|
"skywork_chat", "NVLM_D", "h2ovl_chat"):
|
||||||
return "<image>"
|
return "<image>"
|
||||||
if model_type == "mllama":
|
if model_type == "mllama":
|
||||||
return "<|image|>"
|
return "<|image|>"
|
||||||
|
@ -190,6 +190,7 @@ _MULTIMODAL_MODELS = {
|
|||||||
# [Encoder-decoder]
|
# [Encoder-decoder]
|
||||||
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
|
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
|
||||||
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
|
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
|
||||||
|
"SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
|
||||||
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
|
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
|
||||||
}
|
}
|
||||||
|
|
||||||
|
1014
vllm/model_executor/models/skyworkr1v.py
Normal file
1014
vllm/model_executor/models/skyworkr1v.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -37,8 +37,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
|
|||||||
MLPSpeculatorConfig, MPTConfig,
|
MLPSpeculatorConfig, MPTConfig,
|
||||||
NemotronConfig, NVLM_D_Config,
|
NemotronConfig, NVLM_D_Config,
|
||||||
Olmo2Config, RWConfig,
|
Olmo2Config, RWConfig,
|
||||||
SolarConfig, Telechat2Config,
|
SkyworkR1VChatConfig, SolarConfig,
|
||||||
UltravoxConfig)
|
Telechat2Config, UltravoxConfig)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.transformers_utils.utils import check_gguf_file
|
from vllm.transformers_utils.utils import check_gguf_file
|
||||||
from vllm.utils import resolve_obj_by_qualname
|
from vllm.utils import resolve_obj_by_qualname
|
||||||
@ -76,6 +76,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
|||||||
"NVLM_D": NVLM_D_Config,
|
"NVLM_D": NVLM_D_Config,
|
||||||
"olmo2": Olmo2Config,
|
"olmo2": Olmo2Config,
|
||||||
"solar": SolarConfig,
|
"solar": SolarConfig,
|
||||||
|
"skywork_chat": SkyworkR1VChatConfig,
|
||||||
"telechat": Telechat2Config,
|
"telechat": Telechat2Config,
|
||||||
"ultravox": UltravoxConfig,
|
"ultravox": UltravoxConfig,
|
||||||
**_CONFIG_REGISTRY_OVERRIDE_HF
|
**_CONFIG_REGISTRY_OVERRIDE_HF
|
||||||
|
@ -20,6 +20,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig
|
|||||||
from vllm.transformers_utils.configs.nemotron import NemotronConfig
|
from vllm.transformers_utils.configs.nemotron import NemotronConfig
|
||||||
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
|
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
|
||||||
from vllm.transformers_utils.configs.olmo2 import Olmo2Config
|
from vllm.transformers_utils.configs.olmo2 import Olmo2Config
|
||||||
|
from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig
|
||||||
from vllm.transformers_utils.configs.solar import SolarConfig
|
from vllm.transformers_utils.configs.solar import SolarConfig
|
||||||
from vllm.transformers_utils.configs.telechat2 import Telechat2Config
|
from vllm.transformers_utils.configs.telechat2 import Telechat2Config
|
||||||
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
||||||
@ -42,6 +43,7 @@ __all__ = [
|
|||||||
"NemotronConfig",
|
"NemotronConfig",
|
||||||
"NVLM_D_Config",
|
"NVLM_D_Config",
|
||||||
"Olmo2Config",
|
"Olmo2Config",
|
||||||
|
"SkyworkR1VChatConfig",
|
||||||
"SolarConfig",
|
"SolarConfig",
|
||||||
"Telechat2Config",
|
"Telechat2Config",
|
||||||
"UltravoxConfig",
|
"UltravoxConfig",
|
||||||
|
53
vllm/transformers_utils/configs/skyworkr1v.py
Normal file
53
vllm/transformers_utils/configs/skyworkr1v.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
# Adapted from
|
||||||
|
# https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/configuration_skywork_chat.py
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# SkyworkR1V
|
||||||
|
# Copyright (c) 2025 Skywork
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# --------------------------------------------------------
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class SkyworkR1VChatConfig(PretrainedConfig):
|
||||||
|
model_type = 'internvl_chat'
|
||||||
|
is_composition = True
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
vision_config=None,
|
||||||
|
llm_config=None,
|
||||||
|
use_backbone_lora=0,
|
||||||
|
use_llm_lora=0,
|
||||||
|
select_layer=-1,
|
||||||
|
force_image_size=None,
|
||||||
|
downsample_ratio=0.5,
|
||||||
|
template=None,
|
||||||
|
dynamic_image_size=False,
|
||||||
|
use_thumbnail=False,
|
||||||
|
ps_version='v1',
|
||||||
|
min_dynamic_patch=1,
|
||||||
|
max_dynamic_patch=6,
|
||||||
|
**kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
if vision_config is None:
|
||||||
|
vision_config = {}
|
||||||
|
|
||||||
|
if llm_config is None:
|
||||||
|
llm_config = {}
|
||||||
|
|
||||||
|
self.vision_config = PretrainedConfig(**vision_config)
|
||||||
|
self.text_config = PretrainedConfig(**llm_config)
|
||||||
|
|
||||||
|
self.use_backbone_lora = use_backbone_lora
|
||||||
|
self.use_llm_lora = use_llm_lora
|
||||||
|
self.select_layer = select_layer
|
||||||
|
self.force_image_size = force_image_size
|
||||||
|
self.downsample_ratio = downsample_ratio
|
||||||
|
self.template = template
|
||||||
|
self.dynamic_image_size = dynamic_image_size
|
||||||
|
self.use_thumbnail = use_thumbnail
|
||||||
|
self.ps_version = ps_version # pixel shuffle version
|
||||||
|
self.min_dynamic_patch = min_dynamic_patch
|
||||||
|
self.max_dynamic_patch = max_dynamic_patch
|
Loading…
x
Reference in New Issue
Block a user