[Model][VLM] Add Kimi-VL model support (#16387)
Signed-off-by: courage17340 <courage17340@163.com>
This commit is contained in:
parent 7b5ecf79bd, commit b1308b84a3
@@ -886,6 +886,13 @@ See [this page](#generative-models) for more information on how to use generative models.
  *
  * ✅︎
  * ✅︎
- * `KimiVLForConditionalGeneration`
  * Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking
  * T + I<sup>+</sup>
  * `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking`
  *
  *
  * ✅︎
- * `Llama4ForConditionalGeneration`
  * Llama 4
  * T + I<sup>+</sup>
@@ -364,6 +364,29 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
    )


# Kimi-VL
def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"

    prompts = [
        "<|im_user|>user<|im_middle|><|media_start|>image<|media_content|>"
        f"<|media_pad|><|media_end|>{question}<|im_end|>"
        "<|im_assistant|>assistant<|im_middle|>" for question in questions
    ]

    engine_args = EngineArgs(
        model="moonshotai/Kimi-VL-A3B-Instruct",
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
        trust_remote_code=True,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


# LLaVA-1.5
def run_llava(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"

@@ -966,6 +989,7 @@ model_example_map = {
    "h2ovl_chat": run_h2ovl,
    "idefics3": run_idefics3,
    "internvl_chat": run_internvl,
    "kimi_vl": run_kimi_vl,
    "llava": run_llava,
    "llava-next": run_llava_next,
    "llava-next-video": run_llava_next_video,
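For quick manual testing, the prompt format registered above can also be driven straight through vLLM's `LLM` API. A minimal, hedged sketch (the local image path, question, and sampling settings are illustrative and not part of this PR):

# Illustrative sketch only, not part of the diff.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="moonshotai/Kimi-VL-A3B-Instruct",
          trust_remote_code=True,
          max_model_len=4096)

prompt = ("<|im_user|>user<|im_middle|><|media_start|>image<|media_content|>"
          "<|media_pad|><|media_end|>What is in this image?<|im_end|>"
          "<|im_assistant|>assistant<|im_middle|>")

outputs = llm.generate(
    {"prompt": prompt,
     "multi_modal_data": {"image": Image.open("example.jpg")}},  # any local image
    SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)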
@@ -326,6 +326,45 @@ def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
    )


def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
    model_name = "moonshotai/Kimi-VL-A3B-Instruct"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=4,
        tensor_parallel_size=1,
        limit_mm_per_prompt={"image": len(image_urls)},
        trust_remote_code=True,
    )

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [{
        "role": "user",
        "content": [
            *placeholders,
            {
                "type": "text",
                "text": question
            },
        ],
    }]

    processor = AutoProcessor.from_pretrained(model_name,
                                              trust_remote_code=True)

    prompt = processor.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        image_data=[fetch_image(url) for url in image_urls],
    )


def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData:
    model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"

@@ -640,6 +679,7 @@ model_example_map = {
    "h2ovl_chat": load_h2ovl,
    "idefics3": load_idefics3,
    "internvl_chat": load_internvl,
    "kimi_vl": load_kimi_vl,
    "llama4": load_llama4,
    "mistral3": load_mistral3,
    "mllama": load_mllama,
@@ -10,6 +10,7 @@ pytest-timeout
# testing utils
awscli
backoff # required for phi4mm test
blobfile # required for kimi-vl test
einops # required for MPT, qwen-vl and Mamba
httpx
librosa # required for audio tests
@@ -39,6 +39,8 @@ bitsandbytes==0.45.3
    # via -r requirements/test.in
black==24.10.0
    # via datamodel-code-generator
blobfile==3.0.0
    # via -r requirements/test.in
boto3==1.35.57
    # via tensorizer
botocore==1.35.57
@@ -127,6 +129,7 @@ fastsafetensors==0.1.10
    # via -r requirements/test.in
filelock==3.16.1
    # via
    #   blobfile
    #   datasets
    #   huggingface-hub
    #   ray
@@ -227,7 +230,9 @@ llvmlite==0.44.0
lm-eval==0.4.8
    # via -r requirements/test.in
lxml==5.3.0
    # via sacrebleu
    # via
    #   blobfile
    #   sacrebleu
markdown-it-py==3.0.0
    # via rich
markupsafe==3.0.2
@@ -426,6 +431,8 @@ pybind11==2.13.6
    # via lm-eval
pycparser==2.22
    # via cffi
pycryptodomex==3.22.0
    # via blobfile
pydantic==2.9.2
    # via
    #   datamodel-code-generator
@@ -689,6 +696,7 @@ tzdata==2024.2
    # via pandas
urllib3==2.2.3
    # via
    #   blobfile
    #   botocore
    #   requests
    #   responses
@@ -318,6 +318,18 @@ VLM_TEST_SETTINGS = {
        use_tokenizer_eos=True,
        patch_hf_runner=model_utils.internvl_patch_hf_runner,
    ),
    "kimi_vl": VLMTestInfo(
        models=["moonshotai/Kimi-VL-A3B-Instruct"],
        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
        prompt_formatter=lambda img_prompt: f"<|im_user|>user<|im_middle|>{img_prompt}<|im_end|><|im_assistant|>assistant<|im_middle|>",  # noqa: E501
        img_idx_to_prompt=lambda _: "<|media_start|>image<|media_content|><|media_pad|><|media_end|>",  # noqa: E501
        max_model_len=8192,
        max_num_seqs=2,
        dtype="bfloat16",
        tensor_parallel_size=1,
        vllm_output_post_proc=model_utils.kimi_vl_vllm_to_hf_output,
        marks=[large_gpu_mark(min_gb=48)],
    ),
    "llama4": VLMTestInfo(
        models=["meta-llama/Llama-4-Scout-17B-16E-Instruct"],
        prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n",  # noqa: E501
@@ -68,6 +68,17 @@ def qwen2_vllm_to_hf_output(
    return output_ids, hf_output_str, out_logprobs


def kimi_vl_vllm_to_hf_output(
        vllm_output: RunnerOutput,
        model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]:
    """Sanitize vLLM output [kimi_vl models] to be comparable with HF output."""
    output_ids, output_str, out_logprobs = vllm_output

    hf_output_str = output_str + "<|im_end|>[EOS]"

    return output_ids, hf_output_str, out_logprobs


def llava_image_vllm_to_hf_output(vllm_output: RunnerOutput,
                                  model: str) -> RunnerOutput:
    config = AutoConfig.from_pretrained(model)
@@ -258,6 +258,7 @@ def _test_processing_correctness_mistral(
    "OpenGVLab/InternVL2-1B",
    "HuggingFaceM4/Idefics3-8B-Llama3",
    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
    "moonshotai/Kimi-VL-A3B-Instruct",
    "meta-llama/Llama-4-Scout-17B-16E-Instruct",
    "llava-hf/llava-1.5-7b-hf",
    "llava-hf/llava-v1.6-mistral-7b-hf",
@@ -302,6 +302,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                        trust_remote_code=True),
    "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3",  # noqa: E501
                                                        {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}),  # noqa: E501
    "KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct",  # noqa: E501
                                                      extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"},  # noqa: E501
                                                      trust_remote_code=True),
    "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct",  # noqa: E501
                                                      min_transformers_version="4.51"),
    "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
@@ -512,6 +512,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
                return "<|fim_prefix|><|img|><|fim_suffix|>"
            if model_type == "gemma3":
                return "<start_of_image>"
            if model_type == "kimi_vl":
                return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>"  # noqa: E501

            raise TypeError(f"Unknown {modality} model type: {model_type}")
        elif modality == "audio":
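Once the placeholder above is wired in, Kimi-VL can be exercised end to end through the OpenAI-compatible server. A hedged sketch, assuming a locally running `vllm serve` instance and an arbitrary image URL (server address, key, and URL are illustrative):

# Illustrative sketch only, not part of the diff.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="moonshotai/Kimi-VL-A3B-Instruct",
    messages=[{
        "role": "user",
        "content": [
            {"type": "image_url",
             "image_url": {"url": "https://example.com/demo.jpg"}},
            {"type": "text", "text": "Describe this image."},
        ],
    }],
)
print(resp.choices[0].message.content)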
vllm/model_executor/models/kimi_vl.py (new file, 608 lines)
@@ -0,0 +1,608 @@
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: E501
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
#
# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
#
# Licensing Information:
# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
# - Other parts of the code are licensed under the MIT License.
#
# Apache License, Version 2.0:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License:
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import copy
|
||||
import math
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import (Any, Iterable, List, Literal, Optional, Sequence, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import BatchFeature
|
||||
from transformers.activations import GELUActivation
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||
from vllm.model_executor.models.moonvit import MoonVitPretrainedModel
|
||||
from vllm.model_executor.models.utils import merge_multimodal_embeddings
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
|
||||
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
|
||||
|
||||
from .utils import is_pp_missing_parameter, maybe_prefix
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# For dummy input only
|
||||
@dataclass
|
||||
class MaxImageTokenMeta:
|
||||
width: int = 1024
|
||||
height: int = 1024
|
||||
|
||||
|
||||
class KimiVLMultiModalProjector(nn.Module):
|
||||
|
||||
def __init__(self, config: KimiVLConfig):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = (config.vision_config.hidden_size *
|
||||
config.vision_config.merge_kernel_size[0] *
|
||||
config.vision_config.merge_kernel_size[1])
|
||||
|
||||
self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size,
|
||||
eps=1e-5)
|
||||
self.linear_1 = nn.Linear(self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True)
|
||||
self.act = GELUActivation()
|
||||
self.linear_2 = nn.Linear(self.hidden_size,
|
||||
config.text_config.hidden_size,
|
||||
bias=True)
|
||||
|
||||
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.pre_norm(image_features).view(
|
||||
-1, self.hidden_size)
|
||||
hidden_states = self.linear_1(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class KimiVLImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
pixel_values: Union[torch.Tensor, List[torch.Tensor]]
|
||||
"""
|
||||
Shape:`(num_patches, num_channels, patch_size, patch_size)`
|
||||
"""
|
||||
|
||||
image_grid_hws: torch.Tensor
|
||||
"""Shape:`(num_images, 2)`"""
|
||||
|
||||
|
||||
# TODO: support embeds too
|
||||
# We only support pixel input for kimi-vl now
|
||||
KimiVLImageInputs = KimiVLImagePixelInputs
|
||||
|
||||
|
||||
class KimiVLProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config(KimiVLConfig)
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
hf_processor = self.get_hf_processor()
|
||||
patch_size = hf_processor.image_processor.patch_size
|
||||
kernel_size = hf_processor.image_processor.merge_kernel_size
|
||||
in_token_limit = hf_processor.image_processor.in_token_limit
|
||||
height = image_height
|
||||
width = image_width
|
||||
assert isinstance(height,
|
||||
int), f"height must be int, current height {height}"
|
||||
assert isinstance(width,
|
||||
int), f"width must be int, current width {width}"
|
||||
assert kernel_size is not None, "kernel_size must be specified"
|
||||
|
||||
if (width // patch_size) * (height // patch_size) > in_token_limit:
|
||||
scale = math.sqrt(in_token_limit / ((width // patch_size) *
|
||||
(height // patch_size)))
|
||||
new_w, new_h = int(width * scale), int(height * scale)
|
||||
width, height = new_w, new_h
|
||||
|
||||
kernel_height, kernel_width = kernel_size
|
||||
|
||||
pad_height = (kernel_height * patch_size - height %
|
||||
(kernel_height * patch_size)) % (kernel_height *
|
||||
patch_size)
|
||||
pad_width = (kernel_width * patch_size - width %
|
||||
(kernel_width * patch_size)) % (kernel_width * patch_size)
|
||||
|
||||
# Calculate new dimensions after padding and patching
|
||||
token_height = (height + pad_height) // (kernel_size[0] * patch_size)
|
||||
token_width = (width + pad_width) // (kernel_size[1] * patch_size)
|
||||
return int(token_height * token_width)
|
||||
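As a sanity check of the arithmetic above, a standalone re-derivation for a 1024x1024 dummy image; `patch_size=14` and `merge_kernel_size=(2, 2)` follow the MoonViTConfig defaults, while `in_token_limit=4096` is an assumed value (the real one comes from the HF image processor config):

# Illustrative sketch only, mirrors get_num_image_tokens with assumed values.
import math

patch_size, kernel, limit = 14, (2, 2), 4096
width = height = 1024
if (width // patch_size) * (height // patch_size) > limit:
    scale = math.sqrt(limit / ((width // patch_size) * (height // patch_size)))
    width, height = int(width * scale), int(height * scale)
pad_h = (kernel[0] * patch_size - height % (kernel[0] * patch_size)) % (kernel[0] * patch_size)
pad_w = (kernel[1] * patch_size - width % (kernel[1] * patch_size)) % (kernel[1] * patch_size)
tokens = ((height + pad_h) // (kernel[0] * patch_size)) * ((width + pad_w) // (kernel[1] * patch_size))
print(tokens)  # 1089 under these assumed values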
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
# None means unlimited
|
||||
return {"image": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
return {
|
||||
"image":
|
||||
self.get_num_image_tokens(
|
||||
image_width=MaxImageTokenMeta.width,
|
||||
image_height=MaxImageTokenMeta.height,
|
||||
),
|
||||
}
|
||||
|
||||
@property
|
||||
def image_token_id(self) -> int:
|
||||
return self.get_hf_config().media_placeholder_token_id
|
||||
|
||||
|
||||
class KimiVLDummyInputsBuilder(BaseDummyInputsBuilder[KimiVLProcessingInfo]):
|
||||
|
||||
def __init__(self, info: KimiVLProcessingInfo) -> None:
|
||||
super().__init__(info)
|
||||
|
||||
self.image_token_id = self.info.image_token_id
|
||||
self.image_token = self.info.get_tokenizer().decode(
|
||||
self.image_token_id)
|
||||
|
||||
def get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
width = MaxImageTokenMeta.width
|
||||
height = MaxImageTokenMeta.height
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=width,
|
||||
height=height,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text=self.image_token * num_images,
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
image_grid_hws = hf_inputs.get("image_grid_hws", torch.empty((0, 2)))
|
||||
image_grid_sizes = image_grid_hws.prod(-1)
|
||||
|
||||
# pixel_values is merged as a single large tensor
|
||||
# image_grid_hws is shapes for each subtensor in pixel_values
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_grid_sizes),
|
||||
image_grid_hws=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, Any],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
image_token_id = self.info.image_token_id
|
||||
|
||||
def get_replacement(item_idx: int):
|
||||
images = mm_items.get_items(
|
||||
"image", (ImageEmbeddingItems, ImageProcessorItems))
|
||||
|
||||
if isinstance(images, ImageEmbeddingItems):
|
||||
num_image_tokens = images.get_feature_size(item_idx)
|
||||
else:
|
||||
image_size = images.get_image_size(item_idx)
|
||||
num_image_tokens = self.info.get_num_image_tokens(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
|
||||
return [image_token_id] * num_image_tokens
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[image_token_id],
|
||||
replacement=get_replacement,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(KimiVLMultiModalProcessor,
|
||||
info=KimiVLProcessingInfo,
|
||||
dummy_inputs=KimiVLDummyInputsBuilder)
|
||||
class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
model_config = vllm_config.model_config
|
||||
config: KimiVLConfig = model_config.hf_config
|
||||
self.config = config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
assert isinstance(config.vision_config, MoonViTConfig)
|
||||
|
||||
self.vision_tower = MoonVitPretrainedModel(config.vision_config)
|
||||
|
||||
self.multi_modal_projector = KimiVLMultiModalProjector(config=config)
|
||||
|
||||
self.quant_config = quant_config
|
||||
sub_vllm_config = copy.deepcopy(vllm_config)
|
||||
sub_vllm_config.model_config.hf_config = sub_vllm_config.model_config.hf_config.text_config
|
||||
self.language_model = DeepseekV2Model(
|
||||
vllm_config=sub_vllm_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
self.unpadded_vocab_size = config.text_config.vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.text_config.hidden_size,
|
||||
org_num_embeddings=self.config.text_config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE)
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size, logit_scale)
|
||||
self.sampler = get_sampler()
|
||||
self.media_placeholder: int = self.config.media_placeholder_token_id
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_world_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
# ref: qwen2_vl.py
|
||||
def _validate_and_reshape_mm_tensor(self, mm_input: object,
|
||||
name: str) -> torch.Tensor:
|
||||
if not isinstance(mm_input, (torch.Tensor, list)):
|
||||
raise ValueError(f"Incorrect type of {name}. "
|
||||
f"Got type: {type(mm_input)}")
|
||||
if isinstance(mm_input, torch.Tensor):
|
||||
if mm_input.ndim == 2:
|
||||
return mm_input
|
||||
if mm_input.ndim != 3:
|
||||
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
|
||||
f"Got ndim: {mm_input.ndim} "
|
||||
f"(shape={mm_input.shape})")
|
||||
return mm_input.reshape(-1, mm_input.shape[-1])
|
||||
else:
|
||||
return torch.concat(mm_input)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[KimiVLImageInputs]:
|
||||
# image input type must be pixel values now
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_grid_hws = kwargs.pop("image_grid_hws", None)
|
||||
|
||||
if pixel_values is None:
|
||||
return None
|
||||
|
||||
image_grid_hws = self._validate_and_reshape_mm_tensor(
|
||||
image_grid_hws, "image grid hws")
|
||||
# pixel_values may have complex shapes
|
||||
num_channels = 3
|
||||
patch_size = self.config.vision_config.patch_size
|
||||
if isinstance(pixel_values, list):
|
||||
pixel_values = torch.cat([
|
||||
x.reshape(-1, num_channels, patch_size, patch_size)
|
||||
for x in pixel_values
|
||||
])
|
||||
else:
|
||||
pixel_values = pixel_values.reshape(-1, num_channels, patch_size,
|
||||
patch_size)
|
||||
# fp32 -> bf16
|
||||
pixel_values = pixel_values.to(torch.bfloat16)
|
||||
# image_grid_hws.shape = (N, 2)
|
||||
assert image_grid_hws.ndim == 2, f"unexpected shape for image_grid_hws: {image_grid_hws.shape}"
|
||||
|
||||
return KimiVLImagePixelInputs(
|
||||
type="pixel_values",
|
||||
pixel_values=pixel_values,
|
||||
image_grid_hws=image_grid_hws,
|
||||
)
|
||||
|
||||
    # run the vision tower (ViT) on the preprocessed pixel_values
|
||||
@torch.inference_mode()
|
||||
def _process_image_pixels(self,
|
||||
inputs: KimiVLImagePixelInputs) -> torch.Tensor:
|
||||
assert self.vision_tower is not None
|
||||
|
||||
pixel_values = inputs["pixel_values"]
|
||||
image_grid_hws = inputs["image_grid_hws"]
|
||||
return self.vision_tower(pixel_values, image_grid_hws)
|
||||
|
||||
def _process_image_input(self,
|
||||
image_input: KimiVLImageInputs) -> torch.Tensor:
|
||||
assert image_input["type"] == "pixel_values"
|
||||
image_features = self._process_image_pixels(image_input)
|
||||
assert isinstance(image_features, list)
|
||||
lengths = [x.shape[0] for x in image_features]
|
||||
return self.multi_modal_projector(
|
||||
torch.cat(image_features)).split(lengths)
|
||||
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs: object) -> Optional[NestedTensors]:
|
||||
# Validate the multimodal input keyword arguments
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
# Run multimodal inputs through encoder and projector
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[NestedTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# `get_input_embeddings` should already be implemented for the language
|
||||
# model as one of the requirements of basic vLLM model implementation.
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
placeholder_token_id=self.config.media_placeholder_token_id)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
) -> SamplerOutput:
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner from
|
||||
# `get_multimodal_embeddings` and `get_input_embeddings`, this
|
||||
# condition is only for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
inputs_embeds = None
|
||||
else:
|
||||
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||
image_embeds = self._process_image_input(image_input)
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
image_embeds,
|
||||
placeholder_token_id=self.config.
|
||||
media_placeholder_token_id,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
**kwargs) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata, **kwargs)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
config = self.config.text_config
|
||||
_KEYS_TO_MODIFY_MAPPING = {
|
||||
"language_model.lm_head": "lm_head",
|
||||
"language_model.model": "language_model",
|
||||
}
|
||||
# only doing this for language model part for now.
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
if not config.use_mla:
|
||||
stacked_params_mapping += [
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
]
|
||||
if getattr(config, "n_routed_experts", None):
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=config.n_routed_experts)
|
||||
else:
|
||||
expert_params_mapping = []
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for args in weights:
|
||||
name, loaded_weight = args[:2]
|
||||
kwargs = args[2] if len(args) > 2 else {}
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(config, name)
|
||||
if spec_layer is not None:
|
||||
continue # skip spec decode layers for main model
|
||||
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
or "rotary_emb.sin_cached" in name):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in name:
|
||||
name = name.replace(key_to_modify, new_key)
|
||||
use_default_weight_loading = False
|
||||
if "vision" in name:
|
||||
if self.vision_tower is not None:
|
||||
# We only do sharding for language model and
|
||||
# not vision model for now.
|
||||
use_default_weight_loading = True
|
||||
else:
|
||||
for (param_name, weight_name,
|
||||
shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if (("mlp.experts." in name) and name not in params_dict):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id, **kwargs)
|
||||
break
|
||||
else:
|
||||
for idx, (param_name, weight_name, expert_id,
|
||||
shard_id) in enumerate(expert_params_mapping):
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
name,
|
||||
expert_id=expert_id,
|
||||
shard_id=shard_id,
|
||||
**kwargs)
|
||||
break
|
||||
else:
|
||||
use_default_weight_loading = True
|
||||
if use_default_weight_loading:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight, **kwargs)
|
||||
|
||||
|
||||
def get_spec_layer_idx_from_weight_name(config: DeepseekV2Config,
|
||||
weight_name: str) -> Optional[int]:
|
||||
if hasattr(config,
|
||||
"num_nextn_predict_layers") and (config.num_nextn_predict_layers
|
||||
> 0):
|
||||
layer_idx = config.num_hidden_layers
|
||||
for i in range(config.num_nextn_predict_layers):
|
||||
if weight_name.startswith(f"model.layers.{layer_idx+i}."):
|
||||
return layer_idx + i
|
||||
return None
|
vllm/model_executor/models/moonvit.py (new file, 628 lines)
@@ -0,0 +1,628 @@
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: E501
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
# This file is meant to be used in kimi_vl.py only
# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
#
# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
#
# Licensing Information:
# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
# - Other parts of the code are licensed under the MIT License.
#
# Apache License, Version 2.0:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License:
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
|
||||
from copy import deepcopy
|
||||
from functools import cached_property
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.activations import ACT2FN, PytorchGELUTanh
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import is_flash_attn_2_available
|
||||
|
||||
from vllm.transformers_utils.configs.moonvit import MoonViTConfig
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
else:
|
||||
flash_attn_varlen_func = None
|
||||
|
||||
|
||||
def multihead_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
q_cu_seqlens: Optional[torch.Tensor] = None,
|
||||
k_cu_seqlens: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""Multi-head attention using flash attention 2.
|
||||
|
||||
Args:
|
||||
q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
|
||||
or (tot_seqlens, num_heads, head_dim) if packing.
|
||||
q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
|
||||
The first element should be 0 and the last element should be q.shape[0].
|
||||
k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
|
||||
The first element should be 0 and the last element should be k.shape[0].
|
||||
|
||||
Returns:
|
||||
output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
|
||||
where dim = num_heads * head_dim
|
||||
"""
|
||||
    # Unified sanity checks on the packed (varlen) input format
|
||||
assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
|
||||
assert q_cu_seqlens[-1] == q.shape[
|
||||
0], "q_cu_seqlens must sum to q.shape[0]"
|
||||
assert (k_cu_seqlens[-1] == k.shape[0] ==
|
||||
v.shape[0]), "k_cu_seqlens must sum to k.shape[0]"
|
||||
assert q.dtype in [
|
||||
torch.bfloat16,
|
||||
torch.float16,
|
||||
], f"unsupported dtype {q.dtype} for multihead attn"
|
||||
|
||||
max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item()
|
||||
max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item()
|
||||
attn_out = flash_attn_varlen_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
q_cu_seqlens,
|
||||
k_cu_seqlens,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
causal=False,
|
||||
)
|
||||
attn_out = attn_out.flatten(start_dim=-2)
|
||||
|
||||
return attn_out
|
||||
|
||||
|
||||
def sdpa_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
q_cu_seqlens: Optional[torch.Tensor] = None,
|
||||
k_cu_seqlens: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""SDPA attention.
|
||||
|
||||
Args:
|
||||
q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
|
||||
or (tot_seqlens, num_heads, head_dim) if packing.
|
||||
"""
|
||||
seq_length = q.shape[0]
|
||||
attention_mask = torch.zeros([1, seq_length, seq_length],
|
||||
device=q.device,
|
||||
dtype=torch.bool)
|
||||
for i in range(1, len(q_cu_seqlens)):
|
||||
attention_mask[
|
||||
...,
|
||||
q_cu_seqlens[i - 1]:q_cu_seqlens[i],
|
||||
q_cu_seqlens[i - 1]:q_cu_seqlens[i],
|
||||
] = True
|
||||
q = q.transpose(0, 1)
|
||||
k = k.transpose(0, 1)
|
||||
v = v.transpose(0, 1)
|
||||
attn_output = F.scaled_dot_product_attention(q,
|
||||
k,
|
||||
v,
|
||||
attention_mask,
|
||||
dropout_p=0.0)
|
||||
attn_output = attn_output.transpose(0, 1)
|
||||
attn_output = attn_output.reshape(seq_length, -1)
|
||||
return attn_output
|
||||
|
||||
|
||||
VL_VISION_ATTENTION_FUNCTIONS = {
|
||||
"flash_attention_2": multihead_attention,
|
||||
"sdpa": sdpa_attention,
|
||||
}
|
||||
|
||||
|
||||
def _apply_rope_input_validation(x, freqs_cis):
|
||||
assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
|
||||
assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape)
|
||||
assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape)
|
||||
assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype
|
||||
|
||||
|
||||
def apply_rope(xq: torch.Tensor, xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args: (The leading dimensions of all inputs should be the same)
|
||||
xq: query, tensor of shape (..., num_heads, head_dim)
|
||||
xk: key, tensor of shape (..., num_heads, head_dim)
|
||||
freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid.
|
||||
Returns:
|
||||
xq_out, xk_out: tensors of shape (..., num_heads, head_dim)
|
||||
"""
|
||||
_apply_rope_input_validation(xq, freqs_cis)
|
||||
_apply_rope_input_validation(xk, freqs_cis)
|
||||
|
||||
freqs_cis = freqs_cis.unsqueeze(-2) # ..., 1, head_dim/2
|
||||
# ..., num_heads, head_dim/2
|
||||
xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().view(*xq.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(
|
||||
-2) # ..., num_heads, head_dim
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(
|
||||
-2) # ..., num_heads, head_dim
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
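A minimal shape-level usage sketch for `apply_rope` (illustrative sizes only; not part of the new file):

# Illustrative sketch only, checks the documented shape contract.
import torch
from vllm.model_executor.models.moonvit import apply_rope

num_tokens, num_heads, head_dim = 6, 2, 8
xq = torch.randn(num_tokens, num_heads, head_dim)
xk = torch.randn(num_tokens, num_heads, head_dim)
# cis(freqs) per token position, dtype complex64, last dim = head_dim // 2
freqs_cis = torch.polar(torch.ones(num_tokens, head_dim // 2),
                        torch.randn(num_tokens, head_dim // 2))
xq_out, xk_out = apply_rope(xq, xk, freqs_cis)
assert xq_out.shape == xq.shape and xk_out.shape == xk.shape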
|
||||
|
||||
class Learnable2DInterpPosEmb(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
height: int,
|
||||
width: int,
|
||||
dim: int,
|
||||
interpolation_mode: str = "bicubic") -> None:
|
||||
super().__init__()
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.interpolation_mode = interpolation_mode
|
||||
self.weight = nn.Parameter(torch.empty(height, width, dim))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.normal_(self.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
|
||||
pos_embs = []
|
||||
for shape in grid_hws.tolist():
|
||||
if shape == self.weight.shape[:-1]:
|
||||
pos_embs.append(self.weight.flatten(end_dim=1))
|
||||
else:
|
||||
pos_embs.append(
|
||||
F.interpolate(
|
||||
self.weight.permute((2, 0, 1)).unsqueeze(0),
|
||||
size=shape,
|
||||
mode=self.interpolation_mode,
|
||||
).squeeze(0).permute((1, 2, 0)).flatten(end_dim=1))
|
||||
out = x + torch.cat(pos_embs)
|
||||
return out
|
||||
|
||||
|
||||
class MoonVisionPatchEmbed(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
out_dim: int,
|
||||
in_dim: int = 3,
|
||||
patch_size: Union[int, Tuple[int, int]] = (14, 14),
|
||||
pos_emb_height: int = 14,
|
||||
pos_emb_width: int = 14,
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(
|
||||
patch_size,
|
||||
(int, Sequence)), f"Invalid patch_size type: {type(patch_size)}"
|
||||
if isinstance(patch_size, int):
|
||||
patch_size = (patch_size, patch_size)
|
||||
assert (len(patch_size) == 2
|
||||
), f"Expected patch_size to be a tuple of 2, got {patch_size}"
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.proj = nn.Conv2d(in_dim,
|
||||
out_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size)
|
||||
|
||||
self.pos_emb = Learnable2DInterpPosEmb(height=pos_emb_height,
|
||||
width=pos_emb_width,
|
||||
dim=out_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x (L, Channels): input tensor
|
||||
grid_hw (N, 2): grid height and width
|
||||
|
||||
Returns:
|
||||
(L, Cout) tensor
|
||||
"""
|
||||
x = self.proj(x).view(x.size(0), -1)
|
||||
# apply positional embedding
|
||||
x = self.pos_emb(x, grid_hw)
|
||||
return x
|
||||
|
||||
|
||||
class Rope2DPosEmb(nn.Module):
|
||||
"""2D rotary position embedding with multi-resolution support.
|
||||
|
||||
This class is intended to be used in the following way:
|
||||
1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis.
|
||||
2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration.
|
||||
3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation.
|
||||
The rope is shared across all attention layers and all heads.
|
||||
|
||||
Refs:
|
||||
- RoFormer: https://arxiv.org/abs/2104.09864
|
||||
- VisionLLaMA: https://arxiv.org/abs/2403.00522
|
||||
- https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py
|
||||
|
||||
Args:
|
||||
dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed)
|
||||
max_height (int): the maximum height of the 2D grid
|
||||
max_width (int): the maximum width of the 2D grid
|
||||
theta_base (float): the base of the theta
|
||||
device (str): the device to store the precomputed cis
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim: int,
|
||||
max_height: int,
|
||||
max_width: int,
|
||||
theta_base=10000,
|
||||
device="cuda"):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
assert self.dim % 4 == 0, "dim must be divisible by 4"
|
||||
self.max_height = max_height
|
||||
self.max_width = max_width
|
||||
self.theta_base = theta_base
|
||||
self.device = device
|
||||
|
||||
def extra_repr(self):
|
||||
return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"
|
||||
|
||||
@cached_property
|
||||
def precomputed_freqs_cis(self) -> torch.Tensor:
|
||||
"""Calculate the cis(freqs) for each position in the 2D grid.
|
||||
|
||||
Return: complex tensor of shape (max_height, max_width, dim//2) and value:
|
||||
height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim))
|
||||
            width axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)) with (i in [0, dim//4))
|
||||
note: `cis` is a mathematical notation defined by cis x = cos x + i sin x,
|
||||
"""
|
||||
N = self.max_height * self.max_width
|
||||
flat_pos = torch.arange(0, N).float().to(self.device)
|
||||
x_pos = flat_pos % self.max_width
|
||||
y_pos = flat_pos // self.max_width
|
||||
dim_range = (torch.arange(0, self.dim,
|
||||
4)[:(self.dim // 4)].float().to(self.device)
|
||||
) # C/4
|
||||
freqs = 1.0 / (self.theta_base**(dim_range / self.dim))
|
||||
x_freqs = torch.outer(x_pos, freqs).float() # N, C/4
|
||||
y_freqs = torch.outer(y_pos, freqs).float() # N, C/4
|
||||
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) # N, C/4
|
||||
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) # N, C/4
|
||||
# N, C/4, 2
|
||||
freqs_cis = torch.cat(
|
||||
[x_cis.unsqueeze(dim=-1),
|
||||
y_cis.unsqueeze(dim=-1)], dim=-1)
|
||||
# max_height, max_width, C/2
|
||||
freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
|
||||
return freqs_cis
|
||||
|
||||
def get_freqs_cis_by_seqlens(self, grid_hws: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
grid_hws (torch.Tensor): containing list of (height, width) or (t, height, width) tuples.
|
||||
Returns:
|
||||
freqs_cis: tensor of shape (sum(t * height * width), dim//2)
|
||||
"""
|
||||
shapes = grid_hws.tolist()
|
||||
assert all(1 <= h <= self.max_height and 1 <= w <= self.max_width
|
||||
for h, w in shapes), (
|
||||
shapes,
|
||||
self.max_height,
|
||||
self.max_width,
|
||||
)
|
||||
freqs_cis = torch.cat(
|
||||
[
|
||||
self.precomputed_freqs_cis[:h, :w].reshape(-1, self.dim // 2)
|
||||
for h, w in shapes
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
return freqs_cis
|
||||
|
||||
def get_freqs_cis_by_idx(self, pos_idx: torch.Tensor,
|
||||
pos_idx_mask: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
pos_idx: tensor of shape (..., 2), It contains the (h, w) position indices of each 2D token.
|
||||
pos_idx_mask: a mask of shape (...), the leading dimensions should be the same as pos_idx.
|
||||
Rope will only be applied to the tokens with True mask. `freqs_cis` for the tokens with False mask with be ones.
|
||||
Return:
|
||||
freqs_cis: tensor of shape (..., dim//2)
|
||||
"""
|
||||
assert (pos_idx.shape[:-1] == pos_idx_mask.shape
|
||||
and pos_idx.shape[-1] == 2 and pos_idx.ndim
|
||||
== pos_idx_mask.ndim + 1), (pos_idx.shape, pos_idx_mask.shape)
|
||||
assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype
|
||||
|
||||
shp = pos_idx_mask.shape + (self.dim // 2, ) # ..., head_dim/2
|
||||
freqs_cis = torch.ones(shp, dtype=torch.complex64,
|
||||
device=self.device) # ..., head_dim/2
|
||||
freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[pos_idx[
|
||||
..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask]]
|
||||
return freqs_cis
|
||||
|
||||
|
||||
class MLP2(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
dims: [in_dim, hidden_dim, out_dim]
|
||||
bias: whether to use bias in linear layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dims: list[int], activation, bias=True):
|
||||
super().__init__()
|
||||
assert len(dims) == 3
|
||||
self.fc0 = nn.Linear(dims[0], dims[1], bias=bias)
|
||||
self.fc1 = nn.Linear(dims[1], dims[2], bias=bias)
|
||||
self.activation = activation
|
||||
for m in [self.fc0, self.fc1]:
|
||||
nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features))
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.fc0(x)
|
||||
x = self.activation(x)
|
||||
return self.fc1(x)
|
||||
|
||||
|
||||
class MoonVitEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
hidden_dim: int,
|
||||
mlp_dim: int,
|
||||
*,
|
||||
attn_implementation: str = "sdpa",
|
||||
activation=F.gelu,
|
||||
attn_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.hidden_dim = hidden_dim
|
||||
self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
|
||||
self.attn_implementation = attn_implementation
|
||||
# use fa2 in vllm by default
|
||||
if is_flash_attn_2_available():
|
||||
self.attn_implementation = "flash_attention_2"
|
||||
|
||||
self.norm0 = nn.LayerNorm(hidden_dim)
|
||||
self.norm1 = nn.LayerNorm(hidden_dim)
|
||||
self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
|
||||
self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
|
||||
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)
|
||||
|
||||
def attention_qkvpacked(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rope_freqs_cis: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): (batch_size, seqlen, hidden_dim)
|
||||
cu_seqlens (torch.Tensor):
|
||||
"""
|
||||
xqkv = self.wqkv(x)
|
||||
|
||||
qkv_shape = xqkv.size()[:-1] + (
|
||||
3,
|
||||
self.num_heads,
|
||||
self.hidden_size_per_attention_head,
|
||||
)
|
||||
# xqkv: (batch_size, seqlen, 3, nheads, headdim)
|
||||
xqkv = xqkv.view(*qkv_shape)
|
||||
xq, xk, xv = torch.unbind(xqkv, dim=-3)
|
||||
|
||||
xq, xk = apply_rope(xq, xk, rope_freqs_cis)
|
||||
|
||||
attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
|
||||
attn_out = attn_func(xq,
|
||||
xk,
|
||||
xv,
|
||||
q_cu_seqlens=cu_seqlens,
|
||||
k_cu_seqlens=cu_seqlens)
|
||||
|
||||
attn_out = self.wo(attn_out)
|
||||
return attn_out
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rope_freqs_cis: Union[torch.Tensor, None] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
hidden_states: non-packed (B, N, D) or packed (L, D). if non-packed, seqlens should be None, if packed, seqlens should be set
|
||||
|
||||
Returns:
|
||||
output: same shape of input, non-packed (B, N, D) for non-packed input, (L, D) for packed input
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm0(hidden_states)
|
||||
attn_out = self.attention_qkvpacked(hidden_states,
|
||||
cu_seqlens,
|
||||
rope_freqs_cis=rope_freqs_cis)
|
||||
hidden_states = residual + attn_out
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.mlp(self.norm1(hidden_states))
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MoonVitEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int,
|
||||
num_layers: int,
|
||||
block_cfg: dict,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.rope_2d = Rope2DPosEmb(
|
||||
block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512)
|
||||
self.blocks = nn.ModuleList(
|
||||
[MoonVitEncoderLayer(**block_cfg) for _ in range(num_layers)])
|
||||
self.final_layernorm = nn.LayerNorm(hidden_dim)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
grid_hw: torch.Tensor) -> torch.Tensor:
|
||||
rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(
|
||||
grid_hws=grid_hw)
|
||||
|
||||
lengths = torch.cat((
|
||||
torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
|
||||
grid_hw[:, 0] * grid_hw[:, 1],
|
||||
))
|
||||
cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)
|
||||
|
||||
for _, block in enumerate(self.blocks):
|
||||
hidden_states = block(hidden_states,
|
||||
cu_seqlens,
|
||||
rope_freqs_cis=rope_freqs_cis)
|
||||
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def patch_merger(
|
||||
x: torch.Tensor,
|
||||
grid_hw: torch.Tensor,
|
||||
    merge_kernel_size: tuple[int, int] = (2, 2),
|
||||
) -> List[torch.Tensor]:
|
||||
d_model = x.size(-1)
|
||||
|
||||
outputs = []
|
||||
pre_sum = 0
|
||||
for x_shape in grid_hw.tolist():
|
||||
height, width = x_shape[0], x_shape[1]
|
||||
# Get the current sequence
|
||||
seq = x[pre_sum:pre_sum + height * width]
|
||||
# Reshape along self.merge_kernel_size and concat to the last dimension
|
||||
kernel_height, kernel_width = merge_kernel_size
|
||||
new_height, new_width = height // kernel_height, width // kernel_width
|
||||
reshaped_seq = seq.view(new_height, kernel_height, new_width,
|
||||
kernel_width, d_model)
|
||||
reshaped_seq = reshaped_seq.permute(0, 2, 1, 3, 4).contiguous()
|
||||
padded_seq = reshaped_seq.view(new_height * new_width,
|
||||
kernel_height * kernel_width, -1)
|
||||
outputs.append(padded_seq)
|
||||
pre_sum += height * width
|
||||
|
||||
return outputs
|
||||
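A small shape check for `patch_merger` (illustrative sizes only; not part of the new file):

# Illustrative sketch only, one 4x6 patch grid merged with a 2x2 kernel.
import torch
from vllm.model_executor.models.moonvit import patch_merger

d_model = 8
grid_hw = torch.tensor([[4, 6]])      # one image, 4 x 6 patches
x = torch.randn(4 * 6, d_model)       # packed patch features
out = patch_merger(x, grid_hw, merge_kernel_size=(2, 2))
assert out[0].shape == (2 * 3, 2 * 2, d_model)  # (merged tokens, kernel area, dim)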
|
||||
|
||||
class MoonVitVLProjector(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
        merge_kernel_size: tuple[int, int],
|
||||
hidden_act: str = "gelu",
|
||||
ln_eps: float = 1e-5,
|
||||
out_dim: int = 4096,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = in_channels * merge_kernel_size[
|
||||
0] * merge_kernel_size[1]
|
||||
|
||||
        self.pre_norm = nn.LayerNorm(in_channels, eps=ln_eps)
|
||||
self.linear_1 = nn.Linear(self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True)
|
||||
self.act = ACT2FN[hidden_act]
|
||||
self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size)
|
||||
hidden_states = self.linear_1(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MoonVitPretrainedModel(PreTrainedModel):
|
||||
config_class = MoonViTConfig
|
||||
model_type = "moonvit"
|
||||
_no_split_modules = ["PackingTransformer"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
config = deepcopy(config)
|
||||
self.merge_kernel_size = config.merge_kernel_size
|
||||
self.patch_size = config.patch_size
|
||||
self.patch_embed = MoonVisionPatchEmbed(
|
||||
out_dim=config.hidden_size,
|
||||
patch_size=config.patch_size,
|
||||
pos_emb_height=config.init_pos_emb_height,
|
||||
pos_emb_width=config.init_pos_emb_width,
|
||||
)
|
||||
|
||||
self.encoder = MoonVitEncoder(
|
||||
hidden_dim=config.hidden_size,
|
||||
num_layers=config.num_hidden_layers,
|
||||
block_cfg={
|
||||
"num_heads": config.num_attention_heads,
|
||||
"hidden_dim": config.hidden_size,
|
||||
"mlp_dim": config.intermediate_size,
|
||||
"activation": PytorchGELUTanh(),
|
||||
"attn_bias": True,
|
||||
"attn_implementation": config._attn_implementation,
|
||||
},
|
||||
)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor,
|
||||
grid_hw: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
pixel_values (torch.Tensor): The input pixel values.
|
||||
grid_hw (torch.Tensor): The grid height and width.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tokens.
|
||||
"""
|
||||
hidden_states = self.patch_embed(pixel_values, grid_hw)
|
||||
hidden_states = self.encoder(hidden_states, grid_hw)
|
||||
hidden_states = patch_merger(hidden_states,
|
||||
grid_hw,
|
||||
merge_kernel_size=self.merge_kernel_size)
|
||||
return hidden_states
|
@@ -177,6 +177,7 @@ _MULTIMODAL_MODELS = {
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
    "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"),  # noqa: E501
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
@@ -33,12 +33,13 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
                                             EAGLEConfig, ExaoneConfig,
                                             H2OVLChatConfig,
                                             InternVLChatConfig, JAISConfig,
                                             MedusaConfig, MllamaConfig,
                                             MLPSpeculatorConfig, MPTConfig,
                                             NemotronConfig, NVLM_D_Config,
                                             Olmo2Config, RWConfig,
                                             SkyworkR1VChatConfig, SolarConfig,
                                             Telechat2Config, UltravoxConfig)
                                             KimiVLConfig, MedusaConfig,
                                             MllamaConfig, MLPSpeculatorConfig,
                                             MPTConfig, NemotronConfig,
                                             NVLM_D_Config, Olmo2Config,
                                             RWConfig, SkyworkR1VChatConfig,
                                             SolarConfig, Telechat2Config,
                                             UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname
@@ -62,6 +63,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    "cohere2": Cohere2Config,
    "dbrx": DbrxConfig,
    "deepseek_vl_v2": DeepseekVLV2Config,
    "kimi_vl": KimiVLConfig,
    "mpt": MPTConfig,
    "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
    "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
@@ -13,9 +13,11 @@ from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig
from vllm.transformers_utils.configs.internvl import InternVLChatConfig
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
from vllm.transformers_utils.configs.medusa import MedusaConfig
from vllm.transformers_utils.configs.mllama import MllamaConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.moonvit import MoonViTConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
@@ -40,6 +42,8 @@ __all__ = [
    "ExaoneConfig",
    "MllamaConfig",
    "MLPSpeculatorConfig",
    "MoonViTConfig",
    "KimiVLConfig",
    "NemotronConfig",
    "NVLM_D_Config",
    "Olmo2Config",
vllm/transformers_utils/configs/kimi_vl.py (new file, 36 lines)
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
from typing import Optional, Union

from transformers.configuration_utils import PretrainedConfig

from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
from vllm.transformers_utils.configs.moonvit import MoonViTConfig


class KimiVLConfig(PretrainedConfig):
    model_type = "kimi_vl"

    def __init__(self,
                 vision_config: Optional[Union[dict, MoonViTConfig]] = None,
                 text_config: Optional[Union[dict, DeepseekV2Config]] = None,
                 ignore_index: int = -100,
                 media_placeholder_token_id: int = 163605,
                 pad_token_id: int = 0,
                 **kwargs):
        if vision_config is None:
            vision_config = MoonViTConfig()
        elif isinstance(vision_config, dict):
            vision_config = MoonViTConfig(**vision_config)
        self.vision_config = vision_config

        if text_config is None:
            text_config = DeepseekV2Config()
        elif isinstance(text_config, dict):
            text_config = DeepseekV2Config(**text_config)
        self.text_config = text_config

        self.ignore_index = ignore_index
        self.media_placeholder_token_id = media_placeholder_token_id

        super().__init__(pad_token_id=pad_token_id, **kwargs)
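A hedged instantiation sketch for the config above (the vision override values are arbitrary, and `DeepseekV2Config()` defaults are assumed to be valid on their own):

# Illustrative sketch only, not part of the new file.
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig

cfg = KimiVLConfig(vision_config={"hidden_size": 1152,
                                  "merge_kernel_size": (2, 2)})
print(type(cfg.vision_config).__name__)  # MoonViTConfig
print(cfg.media_placeholder_token_id)    # 163605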
vllm/transformers_utils/configs/moonvit.py (new file, 32 lines)
@@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
from transformers.configuration_utils import PretrainedConfig


class MoonViTConfig(PretrainedConfig):
    model_type = "moonvit"

    def __init__(
        self,
        patch_size: int = 14,
        init_pos_emb_height: int = 64,
        init_pos_emb_width: int = 64,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 27,
        hidden_size: int = 1152,
        intermediate_size: int = 4304,
        merge_kernel_size: tuple[int, int] = (2, 2),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        # Positional embedding config
        self.init_pos_emb_height = init_pos_emb_height
        self.init_pos_emb_width = init_pos_emb_width
        # Transformer config
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        # Patch merger config
        self.merge_kernel_size = merge_kernel_size
@@ -997,13 +997,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            # Return empty ModelRunnerOutput if there's no work to do.
            return EMPTY_MODEL_RUNNER_OUTPUT

        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)
        else:
            mm_embeds = []

        # Prepare the decoder inputs.
        attn_metadata, logits_indices, spec_decode_metadata = (
            self._prepare_inputs(scheduler_output))
@@ -1019,6 +1012,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            num_input_tokens = num_scheduled_tokens
        attn_metadata.num_input_tokens = num_input_tokens

        # _prepare_inputs may reorder the batch, so we must gather multi-
        # modal outputs after that to ensure the correct order
        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)
        else:
            mm_embeds = []

        if self.is_multimodal_model:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)