[Model] Port deepseek-vl2 processor, remove dependency (#12169)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2025-01-18 13:59:39 +08:00 committed by GitHub
parent 813f249f02
commit 02798ecabe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 385 additions and 49 deletions

View File

@ -52,7 +52,6 @@ steps:
- tests/worker - tests/worker
- tests/standalone_tests/lazy_torch_compile.py - tests/standalone_tests/lazy_torch_compile.py
commands: commands:
- pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git # Used by multimoda processing test
- python3 standalone_tests/lazy_torch_compile.py - python3 standalone_tests/lazy_torch_compile.py
- pytest -v -s mq_llm_engine # MQLLMEngine - pytest -v -s mq_llm_engine # MQLLMEngine
- pytest -v -s async_engine # AsyncLLMEngine - pytest -v -s async_engine # AsyncLLMEngine

View File

@ -767,16 +767,10 @@ See [this page](#generative-models) for more information on how to use generativ
<sup>E</sup> Pre-computed embeddings can be inputted for this modality. <sup>E</sup> Pre-computed embeddings can be inputted for this modality.
<sup>+</sup> Multiple items can be inputted per text prompt for this modality. <sup>+</sup> Multiple items can be inputted per text prompt for this modality.
````{note} ```{note}
To use `DeepSeek-VL2` series models, you need to install a fork version `deepseek_vl2` package: To use `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
```shell
pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git
``` ```
Besides, to run `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
````
```{note} ```{note}
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
``` ```

View File

@ -393,7 +393,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
model_example_map = { model_example_map = {
"aria": load_aria, "aria": load_aria,
"deepseek_vl2": load_deepseek_vl2, "deepseek_vl_v2": load_deepseek_vl2,
"h2ovl_chat": load_h2onvl, "h2ovl_chat": load_h2onvl,
"idefics3": load_idefics3, "idefics3": load_idefics3,
"internvl_chat": load_internvl, "internvl_chat": load_internvl,

View File

@ -190,7 +190,7 @@ VLM_TEST_SETTINGS = {
dtype="bfloat16", dtype="bfloat16",
), ),
"deepseek_vl_v2": VLMTestInfo( "deepseek_vl_v2": VLMTestInfo(
models=["deepseek-ai/deepseek-vl2-tiny"], models=["Isotr0py/deepseek-vl2-tiny"], # model repo using dynamic module
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|User|>: {img_prompt}\n\n<|Assistant|>: ", # noqa: E501 prompt_formatter=lambda img_prompt: f"<|User|>: {img_prompt}\n\n<|Assistant|>: ", # noqa: E501
max_model_len=4096, max_model_len=4096,

View File

@ -22,6 +22,8 @@ def _test_processing_correctness(
): ):
if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3": if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3":
hf_overrides = {"architectures": ["MantisForConditionalGeneration"]} hf_overrides = {"architectures": ["MantisForConditionalGeneration"]}
elif model_id == "deepseek-ai/deepseek-vl2-tiny":
hf_overrides = {"architectures": ["DeepseekVLV2ForCausalLM"]}
else: else:
hf_overrides = {} hf_overrides = {}
@ -139,6 +141,7 @@ def _test_processing_correctness(
("rhymes-ai/Aria", {"image": True}), ("rhymes-ai/Aria", {"image": True}),
("Salesforce/blip2-opt-2.7b", {"image": False}), ("Salesforce/blip2-opt-2.7b", {"image": False}),
("facebook/chameleon-7b", {"image": False}), ("facebook/chameleon-7b", {"image": False}),
("deepseek-ai/deepseek-vl2-tiny", {"image": True}),
("adept/fuyu-8b", {"image": False}), ("adept/fuyu-8b", {"image": False}),
("llava-hf/llava-1.5-7b-hf", {"image": True}), ("llava-hf/llava-1.5-7b-hf", {"image": True}),
("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}), ("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),

View File

@ -1,7 +1,7 @@
# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py # adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py
"""Inference-only Deepseek-VL2 model compatible with HuggingFace weights.""" """Inference-only Deepseek-VL2 model compatible with HuggingFace weights."""
import math import math
from functools import cached_property, partial from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union) TypedDict, Union)
@ -9,7 +9,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from transformers import AutoProcessor, BatchFeature, ProcessorMixin from transformers import BatchFeature
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
@ -31,6 +31,8 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config, from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
MlpProjectorConfig, MlpProjectorConfig,
VisionEncoderConfig) VisionEncoderConfig)
from vllm.transformers_utils.processors.deepseek_vl2 import (
DeepseekVLV2Processor)
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
@ -129,25 +131,8 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self): def get_hf_config(self):
return self.ctx.get_hf_config(DeepseekVLV2Config) return self.ctx.get_hf_config(DeepseekVLV2Config)
def get_hf_processor(self) -> ProcessorMixin: def get_hf_processor(self) -> DeepseekVLV2Processor:
# TODO(Isotr0py): we should get rid of dependency on deepseek_vl2 return self.ctx.get_hf_processor(DeepseekVLV2Processor)
# in the future, because it's flasky and lack of maintenance.
try:
from deepseek_vl2.models.processing_deepseek_vl_v2 import (
DeepseekVLV2Processor, select_best_resolution)
AutoProcessor.register("DeepseekVLV2Processor",
DeepseekVLV2Processor)
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"You need to `pip install "
"git+https://github.com/deepseek-ai/DeepSeek-VL2.git` "
"to use this model") from exc
processor = self.ctx.get_hf_processor(DeepseekVLV2Processor)
processor.select_best_resolution = partial(
select_best_resolution,
candidate_resolutions=processor.candidate_resolutions)
return processor
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
@ -224,31 +209,21 @@ class DeepseekVL2MultiModalProcessor(
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
if mm_data: if mm_data:
outputs = self.info.ctx.call_hf_processor( processed_outputs = self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs), self.info.get_hf_processor(**mm_kwargs),
dict(prompt=prompt, **mm_data), dict(prompt=prompt, **mm_data),
mm_kwargs, mm_kwargs,
) )
# Deepseek-vl2 processor don't return BatchFeature,
# we need to manually create it
processed_outputs = dict(input_ids=outputs["input_ids"])
processed_outputs = BatchFeature(data=dict(processed_outputs),
tensor_type="pt")
# Remove batch dimension from processor outputs,
# because we will try batch to create NestedTensors
target_dtype = self.info.ctx.model_config.dtype target_dtype = self.info.ctx.model_config.dtype
pixel_values = outputs["images"].to(target_dtype).squeeze(0) pixel_values = processed_outputs.pop("pixel_values").to(
images_spatial_crop = outputs["images_spatial_crop"].squeeze(0) target_dtype)
# split pixel values into patches corresponding to each image
images_spatial_crop = processed_outputs["images_spatial_crop"]
patches_per_image = [ patches_per_image = [
x.prod().item() + 1 for x in images_spatial_crop x.prod().item() + 1 for x in images_spatial_crop
] ]
pixel_values = pixel_values.split(patches_per_image)
# Rename `images` -> `pixel_values` to avoid confusion processed_outputs["pixel_values"] = pixel_values
processed_outputs["pixel_values"] = list(
pixel_values.split(patches_per_image))
processed_outputs["images_spatial_crop"] = images_spatial_crop
else: else:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
processed_outputs = tokenizer(prompt, processed_outputs = tokenizer(prompt,

View File

@ -0,0 +1,4 @@
from vllm.transformers_utils.processors.deepseek_vl2 import (
DeepseekVLV2Processor)
__all__ = ["DeepseekVLV2Processor"]

View File

@ -0,0 +1,361 @@
# yapf: disable
# ruff: noqa: E501
# coding=utf-8
# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/ff23960c5cf9e6874b44be38af930cfb0ccbb620/deepseek_vl2/models/processing_deepseek_vl_v2.py
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import math
from typing import List, Tuple
import torch
import torchvision.transforms as T
from PIL import Image, ImageOps
from transformers import AutoProcessor, BatchFeature, LlamaTokenizerFast
from transformers.processing_utils import ProcessorMixin
class ImageTransform:
def __init__(self,
mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
normalize: bool = True):
self.mean = mean
self.std = std
self.normalize = normalize
transform_pipelines = [T.ToTensor()]
if normalize:
transform_pipelines.append(T.Normalize(mean, std))
self.transform = T.Compose(transform_pipelines)
def __call__(self, pil_img: Image.Image):
x = self.transform(pil_img)
return x
class DeepseekVLV2Processor(ProcessorMixin):
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
attributes = ["tokenizer"]
def __init__(
self,
tokenizer: LlamaTokenizerFast,
candidate_resolutions: Tuple[Tuple[int, int]],
patch_size: int,
downsample_ratio: int,
image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
normalize: bool = True,
image_token: str = "<image>",
pad_token: str = "<▁pad▁>",
add_special_token: bool = False,
sft_format: str = "deepseek",
mask_prompt: bool = True,
ignore_id: int = -100,
**kwargs,
):
self.candidate_resolutions = candidate_resolutions
self.image_size = candidate_resolutions[0][0]
self.patch_size = patch_size
self.image_mean = image_mean
self.image_std = image_std
self.normalize = normalize
self.downsample_ratio = downsample_ratio
self.image_transform = ImageTransform(mean=image_mean, std=image_std, normalize=normalize)
self.tokenizer = tokenizer
self.tokenizer.padding_side = 'left' # must set thispadding side with make a difference in batch inference
# add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id'
if tokenizer.pad_token is None:
self.tokenizer.add_special_tokens({'pad_token': pad_token})
# add image token
image_token_id = self.tokenizer.vocab.get(image_token)
if image_token_id is None:
special_tokens = [image_token]
special_tokens_dict = {"additional_special_tokens": special_tokens}
self.tokenizer.add_special_tokens(special_tokens_dict)
self.image_token_id = self.tokenizer.vocab.get(image_token)
# add five special tokens for grounding-related tasks
# <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|>
special_tokens = ['<|ref|>', '<|/ref|>', '<|det|>', '<|/det|>', '<|grounding|>']
special_tokens_dict = {"additional_special_tokens": special_tokens}
self.tokenizer.add_special_tokens(special_tokens_dict)
# add special tokens for SFT data
special_tokens = ["<|User|>", "<|Assistant|>"]
special_tokens_dict = {"additional_special_tokens": special_tokens}
self.tokenizer.add_special_tokens(special_tokens_dict)
self.image_token = image_token
self.pad_token = pad_token
self.add_special_token = add_special_token
self.sft_format = sft_format
self.mask_prompt = mask_prompt
self.ignore_id = ignore_id
super().__init__(
tokenizer,
**kwargs,
)
def select_best_resolution(self, image_size):
# used for cropping
original_width, original_height = image_size
best_fit = None
max_effective_resolution = 0
min_wasted_resolution = float("inf")
for width, height in self.candidate_resolutions:
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(
original_width * scale), int(original_height * scale)
effective_resolution = min(downscaled_width * downscaled_height,
original_width * original_height)
wasted_resolution = (width * height) - effective_resolution
if effective_resolution > max_effective_resolution or (
effective_resolution == max_effective_resolution
and wasted_resolution < min_wasted_resolution):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (width, height)
return best_fit
@property
def bos_id(self):
return self.tokenizer.bos_token_id
@property
def eos_id(self):
return self.tokenizer.eos_token_id
@property
def pad_id(self):
return self.tokenizer.pad_token_id
def encode(self, text: str, bos: bool = True, eos: bool = False):
t = self.tokenizer.encode(text, add_special_tokens=False)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t
def decode(self, t: List[int], **kwargs) -> str:
return self.tokenizer.decode(t, **kwargs)
def process_one(
self,
prompt: str,
images: List[Image.Image],
inference_mode: bool = True,
**kwargs,
):
"""
Args:
prompt (str): the formatted prompt;
conversations (List[Dict]): conversations with a list of messages;
images (List[ImageType]): the list of images;
inference_mode (bool): if True, then remove the last eos token;
system_prompt (str): the system prompt;
**kwargs:
Returns:
outputs (BaseProcessorOutput): the output of the processor,
- input_ids (torch.LongTensor): [N + image tokens]
- target_ids (torch.LongTensor): [N + image tokens]
- pixel_values (torch.FloatTensor): [n_patches, 3, H, W]
- image_id (int): the id of the image token
- num_image_tokens (List[int]): the number of image tokens
"""
assert (prompt is not None and images is not None
), "prompt and images must be used at the same time."
sft_format = prompt
tokenized_str, images_list, images_seq_mask, images_spatial_crop, num_image_tokens = self.tokenize_with_images(
sft_format, images, bos=True, eos=True, cropping=len(images) <= 2)
masked_tokenized_str = []
for token_index in tokenized_str:
if token_index != self.image_token_id:
masked_tokenized_str.append(token_index)
else:
masked_tokenized_str.append(self.ignore_id)
assert len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str), \
(f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
f"imags_seq_mask's length {len(images_seq_mask)}, are not equal")
input_ids = torch.LongTensor(tokenized_str)
target_ids = torch.LongTensor(masked_tokenized_str)
images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
# set input_ids < 0 | input_ids == self.image_token_id as ignore_id
target_ids[(input_ids < 0) |
(input_ids == self.image_token_id)] = self.ignore_id
input_ids[input_ids < 0] = self.pad_id
if inference_mode:
# 去掉结尾的eos token
assert input_ids[-1] == self.eos_id
input_ids = input_ids[:-1]
target_ids = target_ids[:-1]
images_seq_mask = images_seq_mask[:-1]
if len(images_list) == 0:
pixel_values = torch.zeros((1, 3, self.image_size, self.image_size))
images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
else:
pixel_values = torch.stack(images_list, dim=0)
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
input_ids = input_ids.unsqueeze(0)
prepare = BatchFeature(
data=dict(
input_ids=input_ids,
pixel_values=pixel_values,
images_seq_mask=images_seq_mask,
images_spatial_crop=images_spatial_crop,
num_image_tokens=num_image_tokens,
),
tensor_type="pt",
)
return prepare
def __call__(
self,
*,
prompt: str,
images: List[Image.Image],
inference_mode: bool = True,
**kwargs,
):
"""
Args:
prompt (str): the formatted prompt;
images (List[ImageType]): the list of images;
inference_mode (bool): if True, then remove the last eos token;
**kwargs:
Returns:
outputs (BaseProcessorOutput): the output of the processor,
- input_ids (torch.LongTensor): [N + image tokens]
- images (torch.FloatTensor): [n_images, 3, H, W]
- image_id (int): the id of the image token
- num_image_tokens (List[int]): the number of image tokens
"""
prepare = self.process_one(
prompt=prompt,
images=images,
inference_mode=inference_mode,
)
return prepare
def tokenize_with_images(
self,
conversation: str,
images: List[Image.Image],
bos: bool = True,
eos: bool = True,
cropping: bool = True,
):
"""Tokenize text with <image> tags."""
assert conversation.count(self.image_token) == len(images)
text_splits = conversation.split(self.image_token)
images_list, images_seq_mask, images_spatial_crop = [], [], []
num_image_tokens = []
tokenized_str = []
for text_sep, image in zip(text_splits, images):
"""encode text_sep"""
tokenized_sep = self.encode(text_sep, bos=False, eos=False)
tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep)
"""select best resolution for anyres"""
if cropping:
best_width, best_height = self.select_best_resolution(image.size)
else:
best_width, best_height = self.image_size, self.image_size
"""process the global view"""
global_view = ImageOps.pad(image, (self.image_size, self.image_size),
color=tuple(int(x * 255) for x in self.image_transform.mean))
images_list.append(self.image_transform(global_view))
"""process the local views"""
local_view = ImageOps.pad(image, (best_width, best_height),
color=tuple(int(x * 255) for x in self.image_transform.mean))
for i in range(0, best_height, self.image_size):
for j in range(0, best_width, self.image_size):
images_list.append(
self.image_transform(local_view.crop((j, i, j + self.image_size, i + self.image_size))))
"""record height / width crop num"""
num_width_tiles, num_height_tiles = best_width // self.image_size, best_height // self.image_size
images_spatial_crop.append([num_width_tiles, num_height_tiles])
"""add image tokens"""
h = w = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)
# global views tokens h * (w + 1), 1 is for line separator
tokenized_image = [self.image_token_id] * h * (w + 1)
# add a separator between global and local views
tokenized_image += [self.image_token_id]
# local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1)
tokenized_image += [self.image_token_id] * (num_height_tiles * h) * (num_width_tiles * w + 1)
tokenized_str += tokenized_image
images_seq_mask += [True] * len(tokenized_image)
num_image_tokens.append(len(tokenized_image))
"""process the last text split"""
tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep)
"""add the bos and eos tokens"""
if bos:
tokenized_str = [self.bos_id] + tokenized_str
images_seq_mask = [False] + images_seq_mask
if eos:
tokenized_str = tokenized_str + [self.eos_id]
images_seq_mask = images_seq_mask + [False]
assert len(tokenized_str) == len(
images_seq_mask), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
return tokenized_str, images_list, images_seq_mask, images_spatial_crop, num_image_tokens
AutoProcessor.register("DeepseekVLV2Processor", DeepseekVLV2Processor)