[Model][VLM] Add Kimi-VL model support (#16387)
Signed-off-by: courage17340 <courage17340@163.com>
This commit is contained in:
  parent 7b5ecf79bd
  commit b1308b84a3
@@ -886,6 +886,13 @@ See [this page](#generative-models) for more information on how to use generative models.
+- * `KimiVLForConditionalGeneration`
+  * Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking
+  * T + I<sup>+</sup>
+  * `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking`
+  *
+  *
+  * ✅︎
 - * `Llama4ForConditionalGeneration`
   * Llama 4
   * T + I<sup>+</sup>
@@ -364,6 +364,29 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
     )


+# Kimi-VL
+def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
+    assert modality == "image"
+
+    prompts = [
+        "<|im_user|>user<|im_middle|><|media_start|>image<|media_content|>"
+        f"<|media_pad|><|media_end|>{question}<|im_end|>"
+        "<|im_assistant|>assistant<|im_middle|>" for question in questions
+    ]
+
+    engine_args = EngineArgs(
+        model="moonshotai/Kimi-VL-A3B-Instruct",
+        max_model_len=4096,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        trust_remote_code=True,
+    )
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
+
+
 # LLaVA-1.5
 def run_llava(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"
@@ -966,6 +989,7 @@ model_example_map = {
     "h2ovl_chat": run_h2ovl,
     "idefics3": run_idefics3,
     "internvl_chat": run_internvl,
+    "kimi_vl": run_kimi_vl,
     "llava": run_llava,
     "llava-next": run_llava_next,
     "llava-next-video": run_llava_next_video,
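For context, the prompt format used by `run_kimi_vl` above can also be exercised directly with vLLM's `LLM` API. A minimal sketch (not part of this commit; the image path `sample.jpg` is a placeholder):

# Editor's sketch: single-image offline inference with the Kimi-VL prompt format.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="moonshotai/Kimi-VL-A3B-Instruct",
          trust_remote_code=True,
          max_model_len=4096)

question = "What is in this image?"
prompt = ("<|im_user|>user<|im_middle|><|media_start|>image<|media_content|>"
          f"<|media_pad|><|media_end|>{question}<|im_end|>"
          "<|im_assistant|>assistant<|im_middle|>")

outputs = llm.generate(
    {
        "prompt": prompt,
        # The single <|media_pad|> placeholder is expanded by the processor.
        "multi_modal_data": {"image": Image.open("sample.jpg")},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)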
@@ -326,6 +326,45 @@ def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
     )


+def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
+    model_name = "moonshotai/Kimi-VL-A3B-Instruct"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=4096,
+        max_num_seqs=4,
+        tensor_parallel_size=1,
+        limit_mm_per_prompt={"image": len(image_urls)},
+        trust_remote_code=True,
+    )
+
+    placeholders = [{"type": "image", "image": url} for url in image_urls]
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            *placeholders,
+            {
+                "type": "text",
+                "text": question
+            },
+        ],
+    }]
+
+    processor = AutoProcessor.from_pretrained(model_name,
+                                              trust_remote_code=True)
+
+    prompt = processor.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True)
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+        image_data=[fetch_image(url) for url in image_urls],
+    )
+
+
 def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData:
     model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"

@@ -640,6 +679,7 @@ model_example_map = {
     "h2ovl_chat": load_h2ovl,
     "idefics3": load_idefics3,
     "internvl_chat": load_internvl,
+    "kimi_vl": load_kimi_vl,
     "llama4": load_llama4,
     "mistral3": load_mistral3,
     "mllama": load_mllama,
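The same multi-image request can be driven end to end with the `LLM` API rather than through the example script. A sketch under the same assumptions as `load_kimi_vl` above (the image URLs are placeholders, and the `fetch_image` helper is the one the example script imports):

# Editor's sketch: multi-image inference, prompt built via the HF chat template.
from transformers import AutoProcessor

from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image

model_name = "moonshotai/Kimi-VL-A3B-Instruct"
image_urls = ["https://example.com/a.jpg", "https://example.com/b.jpg"]

processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
messages = [{
    "role": "user",
    "content": [
        *({"type": "image", "image": url} for url in image_urls),
        {"type": "text", "text": "Compare the two images."},
    ],
}]
prompt = processor.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

llm = LLM(model=model_name,
          trust_remote_code=True,
          max_model_len=4096,
          limit_mm_per_prompt={"image": len(image_urls)})
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": [fetch_image(u) for u in image_urls]},
    },
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)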
@@ -10,6 +10,7 @@ pytest-timeout
 # testing utils
 awscli
 backoff # required for phi4mm test
+blobfile # required for kimi-vl test
 einops # required for MPT, qwen-vl and Mamba
 httpx
 librosa # required for audio tests
@@ -39,6 +39,8 @@ bitsandbytes==0.45.3
     # via -r requirements/test.in
 black==24.10.0
     # via datamodel-code-generator
+blobfile==3.0.0
+    # via -r requirements/test.in
 boto3==1.35.57
     # via tensorizer
 botocore==1.35.57
@@ -127,6 +129,7 @@ fastsafetensors==0.1.10
     # via -r requirements/test.in
 filelock==3.16.1
     # via
+    #   blobfile
     #   datasets
     #   huggingface-hub
     #   ray
@@ -227,7 +230,9 @@ llvmlite==0.44.0
 lm-eval==0.4.8
     # via -r requirements/test.in
 lxml==5.3.0
-    # via sacrebleu
+    # via
+    #   blobfile
+    #   sacrebleu
 markdown-it-py==3.0.0
     # via rich
 markupsafe==3.0.2
@@ -426,6 +431,8 @@ pybind11==2.13.6
     # via lm-eval
 pycparser==2.22
     # via cffi
+pycryptodomex==3.22.0
+    # via blobfile
 pydantic==2.9.2
     # via
     #   datamodel-code-generator
@@ -689,6 +696,7 @@ tzdata==2024.2
     # via pandas
 urllib3==2.2.3
     # via
+    #   blobfile
     #   botocore
     #   requests
     #   responses
@@ -318,6 +318,18 @@ VLM_TEST_SETTINGS = {
         use_tokenizer_eos=True,
         patch_hf_runner=model_utils.internvl_patch_hf_runner,
     ),
+    "kimi_vl": VLMTestInfo(
+        models=["moonshotai/Kimi-VL-A3B-Instruct"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|im_user|>user<|im_middle|>{img_prompt}<|im_end|><|im_assistant|>assistant<|im_middle|>",  # noqa: E501
+        img_idx_to_prompt=lambda _: "<|media_start|>image<|media_content|><|media_pad|><|media_end|>",  # noqa: E501
+        max_model_len=8192,
+        max_num_seqs=2,
+        dtype="bfloat16",
+        tensor_parallel_size=1,
+        vllm_output_post_proc=model_utils.kimiv_vl_vllm_to_hf_output,
+        marks=[large_gpu_mark(min_gb=48)],
+    ),
     "llama4": VLMTestInfo(
         models=["meta-llama/Llama-4-Scout-17B-16E-Instruct"],
         prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n",  # noqa: E501
@@ -68,6 +68,17 @@ def qwen2_vllm_to_hf_output(
     return output_ids, hf_output_str, out_logprobs


+def kimiv_vl_vllm_to_hf_output(
+        vllm_output: RunnerOutput,
+        model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]:
+    """Sanitize vllm output [kimi_vl models] to be comparable with hf output."""
+    output_ids, output_str, out_logprobs = vllm_output
+
+    hf_output_str = output_str + "<|im_end|>[EOS]"
+
+    return output_ids, hf_output_str, out_logprobs
+
+
 def llava_image_vllm_to_hf_output(vllm_output: RunnerOutput,
                                   model: str) -> RunnerOutput:
     config = AutoConfig.from_pretrained(model)
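The post-processor only appends the stop tokens that the HF runner keeps but vLLM strips, so the two outputs compare equal. A small illustration with hypothetical output values (not part of the diff):

# Editor's sketch: what kimiv_vl_vllm_to_hf_output does to a vLLM output tuple.
vllm_out = ([1, 2, 3], "The image shows a cat.", None)
ids, text, logprobs = kimiv_vl_vllm_to_hf_output(
    vllm_out, "moonshotai/Kimi-VL-A3B-Instruct")
assert text == "The image shows a cat.<|im_end|>[EOS]"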
@@ -258,6 +258,7 @@ def _test_processing_correctness_mistral(
     "OpenGVLab/InternVL2-1B",
     "HuggingFaceM4/Idefics3-8B-Llama3",
     "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
+    "moonshotai/Kimi-VL-A3B-Instruct",
     "meta-llama/Llama-4-Scout-17B-16E-Instruct",
     "llava-hf/llava-1.5-7b-hf",
     "llava-hf/llava-v1.6-mistral-7b-hf",
@@ -302,6 +302,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                          trust_remote_code=True),
     "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3",  # noqa: E501
                                                         {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}),  # noqa: E501
+    "KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct",  # noqa: E501
+                                                      extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"},  # noqa: E501
+                                                      trust_remote_code=True),
     "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct",  # noqa: E501
                                                       min_transformers_version="4.51"),
     "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
@@ -512,6 +512,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
             return "<|fim_prefix|><|img|><|fim_suffix|>"
         if model_type == "gemma3":
             return "<start_of_image>"
+        if model_type == "kimi_vl":
+            return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>"  # noqa: E501

         raise TypeError(f"Unknown {modality} model type: {model_type}")
     elif modality == "audio":
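With this placeholder registered, chat requests to the OpenAI-compatible server need no hand-written media tokens. A sketch (not part of the diff), assuming a server started with `vllm serve moonshotai/Kimi-VL-A3B-Instruct --trust-remote-code` and a placeholder image URL:

# Editor's sketch: the server inserts the kimi_vl image placeholder automatically.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="moonshotai/Kimi-VL-A3B-Instruct",
    messages=[{
        "role": "user",
        "content": [
            {"type": "image_url",
             "image_url": {"url": "https://example.com/a.jpg"}},
            {"type": "text", "text": "Describe this image."},
        ],
    }],
)
print(resp.choices[0].message.content)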
vllm/model_executor/models/kimi_vl.py (new file, 608 lines)
@@ -0,0 +1,608 @@
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: E501
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
#
# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
#
# Licensing Information:
# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
# - Other parts of the code are licensed under the MIT License.
#
# Apache License, Version 2.0:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License:
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import copy
import math
from collections.abc import Mapping
from dataclasses import dataclass
from typing import (Any, Iterable, List, Literal, Optional, Sequence, Tuple,
                    TypedDict, Union)

import torch
from torch import nn
from transformers import BatchFeature
from transformers.activations import GELUActivation

from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.moonvit import MoonVitPretrainedModel
from vllm.model_executor.models.utils import merge_multimodal_embeddings
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                    NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config

from .utils import is_pp_missing_parameter, maybe_prefix

logger = init_logger(__name__)


# For dummy input only
@dataclass
class MaxImageTokenMeta:
    width: int = 1024
    height: int = 1024


class KimiVLMultiModalProjector(nn.Module):

    def __init__(self, config: KimiVLConfig):
        super().__init__()

        self.hidden_size = (config.vision_config.hidden_size *
                            config.vision_config.merge_kernel_size[0] *
                            config.vision_config.merge_kernel_size[1])

        self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size,
                                           eps=1e-5)
        self.linear_1 = nn.Linear(self.hidden_size,
                                  self.hidden_size,
                                  bias=True)
        self.act = GELUActivation()
        self.linear_2 = nn.Linear(self.hidden_size,
                                  config.text_config.hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.pre_norm(image_features).view(
            -1, self.hidden_size)
        hidden_states = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class KimiVLImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    pixel_values: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape:`(num_patches, num_channels, patch_size, patch_size)`
    """

    image_grid_hws: torch.Tensor
    """Shape:`(num_images, 2)`"""


# TODO: support embeds too
# We only support pixel input for kimi-vl now
KimiVLImageInputs = KimiVLImagePixelInputs


class KimiVLProcessingInfo(BaseProcessingInfo):

    def get_hf_config(self):
        return self.ctx.get_hf_config(KimiVLConfig)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        hf_processor = self.get_hf_processor()
        patch_size = hf_processor.image_processor.patch_size
        kernel_size = hf_processor.image_processor.merge_kernel_size
        in_token_limit = hf_processor.image_processor.in_token_limit
        height = image_height
        width = image_width
        assert isinstance(height,
                          int), f"height must be int, current height {height}"
        assert isinstance(width,
                          int), f"width must be int, current width {width}"
        assert kernel_size is not None, "kernel_size must be specified"

        if (width // patch_size) * (height // patch_size) > in_token_limit:
            scale = math.sqrt(in_token_limit / ((width // patch_size) *
                                                (height // patch_size)))
            new_w, new_h = int(width * scale), int(height * scale)
            width, height = new_w, new_h

        kernel_height, kernel_width = kernel_size

        pad_height = (kernel_height * patch_size - height %
                      (kernel_height * patch_size)) % (kernel_height *
                                                       patch_size)
        pad_width = (kernel_width * patch_size - width %
                     (kernel_width * patch_size)) % (kernel_width * patch_size)

        # Calculate new dimensions after padding and patching
        token_height = (height + pad_height) // (kernel_size[0] * patch_size)
        token_width = (width + pad_width) // (kernel_size[1] * patch_size)
        return int(token_height * token_width)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # None means unlimited
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {
            "image":
            self.get_num_image_tokens(
                image_width=MaxImageTokenMeta.width,
                image_height=MaxImageTokenMeta.height,
            ),
        }

    @property
    def image_token_id(self) -> int:
        return self.get_hf_config().media_placeholder_token_id


class KimiVLDummyInputsBuilder(BaseDummyInputsBuilder[KimiVLProcessingInfo]):

    def __init__(self, info: KimiVLProcessingInfo) -> None:
        super().__init__(info)

        self.image_token_id = self.info.image_token_id
        self.image_token = self.info.get_tokenizer().decode(
            self.image_token_id)

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)

        width = MaxImageTokenMeta.width
        height = MaxImageTokenMeta.height
        mm_data = {
            "image":
            self._get_dummy_images(width=width,
                                   height=height,
                                   num_images=num_images)
        }

        return ProcessorInputs(
            prompt_text=self.image_token * num_images,
            mm_data=mm_data,
        )


class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        image_grid_hws = hf_inputs.get("image_grid_hws", torch.empty((0, 2)))
        image_grid_sizes = image_grid_hws.prod(-1)

        # pixel_values is merged as a single large tensor
        # image_grid_hws is shapes for each subtensor in pixel_values
        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", image_grid_sizes),
            image_grid_hws=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        image_token_id = self.info.image_token_id

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]


@MULTIMODAL_REGISTRY.register_processor(KimiVLMultiModalProcessor,
                                        info=KimiVLProcessingInfo,
                                        dummy_inputs=KimiVLDummyInputsBuilder)
class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        super().__init__()
        model_config = vllm_config.model_config
        config: KimiVLConfig = model_config.hf_config
        self.config = config
        quant_config = vllm_config.quant_config

        assert isinstance(config.vision_config, MoonViTConfig)

        self.vision_tower = MoonVitPretrainedModel(config.vision_config)

        self.multi_modal_projector = KimiVLMultiModalProjector(config=config)

        self.quant_config = quant_config
        sub_vllm_config = copy.deepcopy(vllm_config)
        sub_vllm_config.model_config.hf_config = sub_vllm_config.model_config.hf_config.text_config
        self.language_model = DeepseekV2Model(
            vllm_config=sub_vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.config.text_config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = get_sampler()
        self.media_placeholder: int = self.config.media_placeholder_token_id
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_world_size = get_tensor_model_parallel_world_size()

    # ref: qwen2_vl.py
    def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                        name: str) -> torch.Tensor:
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
                                 f"Got ndim: {mm_input.ndim} "
                                 f"(shape={mm_input.shape})")
            return mm_input.reshape(-1, mm_input.shape[-1])
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[KimiVLImageInputs]:
        # image input type must be pixel values now
        pixel_values = kwargs.pop("pixel_values", None)
        image_grid_hws = kwargs.pop("image_grid_hws", None)

        if pixel_values is None:
            return None

        image_grid_hws = self._validate_and_reshape_mm_tensor(
            image_grid_hws, "image grid hws")
        # pixel_values may have complex shapes
        num_channels = 3
        patch_size = self.config.vision_config.patch_size
        if isinstance(pixel_values, list):
            pixel_values = torch.cat([
                x.reshape(-1, num_channels, patch_size, patch_size)
                for x in pixel_values
            ])
        else:
            pixel_values = pixel_values.reshape(-1, num_channels, patch_size,
                                                patch_size)
        # fp32 -> bf16
        pixel_values = pixel_values.to(torch.bfloat16)
        # image_grid_hws.shape = (N, 2)
        assert image_grid_hws.ndim == 2, f"unexpected shape for image_grid_hws: {image_grid_hws.shape}"

        return KimiVLImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            image_grid_hws=image_grid_hws,
        )

    # perform vt on processored pixel_values
    @torch.inference_mode()
    def _process_image_pixels(self,
                              inputs: KimiVLImagePixelInputs) -> torch.Tensor:
        assert self.vision_tower is not None

        pixel_values = inputs["pixel_values"]
        image_grid_hws = inputs["image_grid_hws"]
        return self.vision_tower(pixel_values, image_grid_hws)

    def _process_image_input(self,
                             image_input: KimiVLImageInputs) -> torch.Tensor:
        assert image_input["type"] == "pixel_values"
        image_features = self._process_image_pixels(image_input)
        assert isinstance(image_features, list)
        lengths = [x.shape[0] for x in image_features]
        return self.multi_modal_projector(
            torch.cat(image_features)).split(lengths)

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> Optional[NestedTensors]:
        # Validate the multimodal input keyword arguments
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        # Run multimodal inputs through encoder and projector
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:

        # `get_input_embeddings` should already be implemented for the language
        # model as one of the requirements of basic vLLM model implementation.
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)

        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                multimodal_embeddings=multimodal_embeddings,
                placeholder_token_id=self.config.media_placeholder_token_id)

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        if intermediate_tensors is not None:
            inputs_embeds = None
        # NOTE: In v1, inputs_embeds is always generated at model runner from
        # `get_multimodal_embeddings` and `get_input_embeddings`, this
        # condition is only for v0 compatibility.
        elif inputs_embeds is None:
            image_input = self._parse_and_validate_image_input(**kwargs)
            if image_input is None:
                inputs_embeds = None
            else:
                inputs_embeds = self.get_input_embeddings(input_ids)
                image_embeds = self._process_image_input(image_input)
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids,
                    inputs_embeds,
                    image_embeds,
                    placeholder_token_id=self.config.
                    media_placeholder_token_id,
                )
                input_ids = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata,
                       **kwargs) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata, **kwargs)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        config = self.config.text_config
        _KEYS_TO_MODIFY_MAPPING = {
            "language_model.lm_head": "lm_head",
            "language_model.model": "language_model",
        }
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        if not config.use_mla:
            stacked_params_mapping += [
                (".qkv_proj", ".q_proj", "q"),
                (".qkv_proj", ".k_proj", "k"),
                (".qkv_proj", ".v_proj", "v"),
            ]
        if getattr(config, "n_routed_experts", None):
            # Params for weights, fp8 weight scales, fp8 activation scales
            # (param_name, weight_name, expert_id, shard_id)
            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                ckpt_gate_proj_name="gate_proj",
                ckpt_down_proj_name="down_proj",
                ckpt_up_proj_name="up_proj",
                num_experts=config.n_routed_experts)
        else:
            expert_params_mapping = []

        params_dict = dict(self.named_parameters())
        for args in weights:
            name, loaded_weight = args[:2]
            kwargs = args[2] if len(args) > 2 else {}
            if "rotary_emb.inv_freq" in name:
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model

            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    # We have mlp.experts[0].gate_proj in the checkpoint.
                    # Since we handle the experts below in expert_params_mapping,
                    # we need to skip here BEFORE we update the name, otherwise
                    # name will be updated to mlp.experts[0].gate_up_proj, which
                    # will then be updated below in expert_params_mapping
                    # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                    if (("mlp.experts." in name) and name not in params_dict):
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id, **kwargs)
                    break
                else:
                    for idx, (param_name, weight_name, expert_id,
                              shard_id) in enumerate(expert_params_mapping):
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)

                        if is_pp_missing_parameter(name, self):
                            continue

                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        weight_loader(param,
                                      loaded_weight,
                                      name,
                                      expert_id=expert_id,
                                      shard_id=shard_id,
                                      **kwargs)
                        break
                    else:
                        use_default_weight_loading = True
            if use_default_weight_loading:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight, **kwargs)


def get_spec_layer_idx_from_weight_name(config: DeepseekV2Config,
                                        weight_name: str) -> Optional[int]:
    if hasattr(config,
               "num_nextn_predict_layers") and (config.num_nextn_predict_layers
                                                > 0):
        layer_idx = config.num_hidden_layers
        for i in range(config.num_nextn_predict_layers):
            if weight_name.startswith(f"model.layers.{layer_idx+i}."):
                return layer_idx + i
    return None
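To make the placeholder accounting above concrete, here is a worked re-derivation of `KimiVLProcessingInfo.get_num_image_tokens` for the 1024x1024 `MaxImageTokenMeta` dummy image (not part of the diff). The processor values `patch_size=14`, `merge_kernel_size=(2, 2)` and a non-binding `in_token_limit` are assumptions for illustration:

# Editor's sketch: standalone re-derivation of the image token count.
import math


def kimi_vl_image_tokens(width, height, patch_size=14, kernel=(2, 2),
                         in_token_limit=10_000):
    # Rescale only if the raw patch count would exceed the assumed token limit.
    if (width // patch_size) * (height // patch_size) > in_token_limit:
        scale = math.sqrt(in_token_limit /
                          ((width // patch_size) * (height // patch_size)))
        width, height = int(width * scale), int(height * scale)
    kh, kw = kernel
    # Pad each side up to a multiple of kernel * patch_size, then merge.
    pad_h = (kh * patch_size - height % (kh * patch_size)) % (kh * patch_size)
    pad_w = (kw * patch_size - width % (kw * patch_size)) % (kw * patch_size)
    return ((height + pad_h) // (kh * patch_size)) * \
           ((width + pad_w) // (kw * patch_size))


print(kimi_vl_image_tokens(1024, 1024))  # -> 1369 placeholder tokens (37 x 37)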
vllm/model_executor/models/moonvit.py (new file, 628 lines)
@@ -0,0 +1,628 @@
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: E501
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
# This file is meant to be used in kimi_vl.py only
# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
#
# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
#
# Licensing Information:
# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
# - Other parts of the code are licensed under the MIT License.
#
# Apache License, Version 2.0:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License:
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
from copy import deepcopy
from functools import cached_property
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN, PytorchGELUTanh
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import is_flash_attn_2_available

from vllm.transformers_utils.configs.moonvit import MoonViTConfig

if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func
else:
    flash_attn_varlen_func = None


def multihead_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
):
    """Multi-head attention using flash attention 2.

    Args:
        q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
            or (tot_seqlens, num_heads, head_dim) if packing.
        q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
            The first element should be 0 and the last element should be q.shape[0].
        k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
            The first element should be 0 and the last element should be k.shape[0].

    Returns:
        output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
            where dim = num_heads * head_dim
    """
    # Unified format legal check
    assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
    assert q_cu_seqlens[-1] == q.shape[
        0], "q_cu_seqlens must sum to q.shape[0]"
    assert (k_cu_seqlens[-1] == k.shape[0] ==
            v.shape[0]), "k_cu_seqlens must sum to k.shape[0]"
    assert q.dtype in [
        torch.bfloat16,
        torch.float16,
    ], f"unsupported dtype {q.dtype} for multihead attn"

    max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item()
    max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item()
    attn_out = flash_attn_varlen_func(
        q,
        k,
        v,
        q_cu_seqlens,
        k_cu_seqlens,
        max_seqlen_q,
        max_seqlen_k,
        causal=False,
    )
    attn_out = attn_out.flatten(start_dim=-2)

    return attn_out


def sdpa_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """SDPA attention.

    Args:
        q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
            or (tot_seqlens, num_heads, head_dim) if packing.
    """
    seq_length = q.shape[0]
    attention_mask = torch.zeros([1, seq_length, seq_length],
                                 device=q.device,
                                 dtype=torch.bool)
    for i in range(1, len(q_cu_seqlens)):
        attention_mask[
            ...,
            q_cu_seqlens[i - 1]:q_cu_seqlens[i],
            q_cu_seqlens[i - 1]:q_cu_seqlens[i],
        ] = True
    q = q.transpose(0, 1)
    k = k.transpose(0, 1)
    v = v.transpose(0, 1)
    attn_output = F.scaled_dot_product_attention(q,
                                                 k,
                                                 v,
                                                 attention_mask,
                                                 dropout_p=0.0)
    attn_output = attn_output.transpose(0, 1)
    attn_output = attn_output.reshape(seq_length, -1)
    return attn_output


VL_VISION_ATTENTION_FUNCTIONS = {
    "flash_attention_2": multihead_attention,
    "sdpa": sdpa_attention,
}


def _apply_rope_input_validation(x, freqs_cis):
    assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
    assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape)
    assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape)
    assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype


def apply_rope(xq: torch.Tensor, xk: torch.Tensor,
               freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Args: (The leading dimensions of all inputs should be the same)
        xq: query, tensor of shape (..., num_heads, head_dim)
        xk: key, tensor of shape (..., num_heads, head_dim)
        freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid.
    Returns:
        xq_out, xk_out: tensors of shape (..., num_heads, head_dim)
    """
    _apply_rope_input_validation(xq, freqs_cis)
    _apply_rope_input_validation(xk, freqs_cis)

    freqs_cis = freqs_cis.unsqueeze(-2)  # ..., 1, head_dim/2
    # ..., num_heads, head_dim/2
    xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().view(*xq.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(
        -2)  # ..., num_heads, head_dim
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(
        -2)  # ..., num_heads, head_dim
    return xq_out.type_as(xq), xk_out.type_as(xk)


class Learnable2DInterpPosEmb(nn.Module):

    def __init__(self,
                 height: int,
                 width: int,
                 dim: int,
                 interpolation_mode: str = "bicubic") -> None:
        super().__init__()
        self.height = height
        self.width = width
        self.interpolation_mode = interpolation_mode
        self.weight = nn.Parameter(torch.empty(height, width, dim))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight)

    def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
        pos_embs = []
        for shape in grid_hws.tolist():
            if shape == self.weight.shape[:-1]:
                pos_embs.append(self.weight.flatten(end_dim=1))
            else:
                pos_embs.append(
                    F.interpolate(
                        self.weight.permute((2, 0, 1)).unsqueeze(0),
                        size=shape,
                        mode=self.interpolation_mode,
                    ).squeeze(0).permute((1, 2, 0)).flatten(end_dim=1))
        out = x + torch.cat(pos_embs)
        return out


class MoonVisionPatchEmbed(nn.Module):

    def __init__(
        self,
        out_dim: int,
        in_dim: int = 3,
        patch_size: Union[int, Tuple[int, int]] = (14, 14),
        pos_emb_height: int = 14,
        pos_emb_width: int = 14,
    ):
        super().__init__()
        assert isinstance(
            patch_size,
            (int, Sequence)), f"Invalid patch_size type: {type(patch_size)}"
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        assert (len(patch_size) == 2
                ), f"Expected patch_size to be a tuple of 2, got {patch_size}"
        self.patch_size = patch_size

        self.proj = nn.Conv2d(in_dim,
                              out_dim,
                              kernel_size=patch_size,
                              stride=patch_size)

        self.pos_emb = Learnable2DInterpPosEmb(height=pos_emb_height,
                                               width=pos_emb_width,
                                               dim=out_dim)

    def forward(self, x: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (L, Channels): input tensor
            grid_hw (N, 2): grid height and width

        Returns:
            (L, Cout) tensor
        """
        x = self.proj(x).view(x.size(0), -1)
        # apply positional embedding
        x = self.pos_emb(x, grid_hw)
        return x


class Rope2DPosEmb(nn.Module):
    """2D rotary position embedding with multi-resolution support.

    This class is intended to be used in the following way:
    1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis.
    2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration.
    3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation.
    The rope is shared across all attention layers and all heads.

    Refs:
    - RoFormer: https://arxiv.org/abs/2104.09864
    - VisionLLaMA: https://arxiv.org/abs/2403.00522
    - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py

    Args:
        dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed)
        max_height (int): the maximum height of the 2D grid
        max_width (int): the maximum width of the 2D grid
        theta_base (float): the base of the theta
        device (str): the device to store the precomputed cis
    """

    def __init__(self,
                 dim: int,
                 max_height: int,
                 max_width: int,
                 theta_base=10000,
                 device="cuda"):
        super().__init__()
        self.dim = dim
        assert self.dim % 4 == 0, "dim must be divisible by 4"
        self.max_height = max_height
        self.max_width = max_width
        self.theta_base = theta_base
        self.device = device

    def extra_repr(self):
        return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"

    @cached_property
    def precomputed_freqs_cis(self) -> torch.Tensor:
        """Calculate the cis(freqs) for each position in the 2D grid.

        Return: complex tensor of shape (max_height, max_width, dim//2) and value:
            height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim))
            weight axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)) with (i in [0, dim//4))
            note: `cis` is a mathematical notation defined by cis x = cos x + i sin x,
        """
        N = self.max_height * self.max_width
        flat_pos = torch.arange(0, N).float().to(self.device)
        x_pos = flat_pos % self.max_width
        y_pos = flat_pos // self.max_width
        dim_range = (torch.arange(0, self.dim,
                                  4)[:(self.dim // 4)].float().to(self.device)
                     )  # C/4
        freqs = 1.0 / (self.theta_base**(dim_range / self.dim))
        x_freqs = torch.outer(x_pos, freqs).float()  # N, C/4
        y_freqs = torch.outer(y_pos, freqs).float()  # N, C/4
        x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)  # N, C/4
        y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)  # N, C/4
        # N, C/4, 2
        freqs_cis = torch.cat(
            [x_cis.unsqueeze(dim=-1),
             y_cis.unsqueeze(dim=-1)], dim=-1)
        # max_height, max_width, C/2
        freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
        return freqs_cis

    def get_freqs_cis_by_seqlens(self, grid_hws: torch.Tensor) -> torch.Tensor:
        """
        Args:
            grid_hws (torch.Tensor): containing list of (height, width) or (t, height, width) tuples.
        Returns:
            freqs_cis: tensor of shape (sum(t * height * width), dim//2)
        """
        shapes = grid_hws.tolist()
        assert all(1 <= h <= self.max_height and 1 <= w <= self.max_width
                   for h, w in shapes), (
                       shapes,
                       self.max_height,
                       self.max_width,
                   )
        freqs_cis = torch.cat(
            [
                self.precomputed_freqs_cis[:h, :w].reshape(-1, self.dim // 2)
                for h, w in shapes
            ],
            dim=0,
        )
        return freqs_cis

    def get_freqs_cis_by_idx(self, pos_idx: torch.Tensor,
                             pos_idx_mask: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pos_idx: tensor of shape (..., 2), It contains the (h, w) position indices of each 2D token.
            pos_idx_mask: a mask of shape (...), the leading dimensions should be the same as pos_idx.
                Rope will only be applied to the tokens with True mask. `freqs_cis` for the tokens with False mask with be ones.
        Return:
            freqs_cis: tensor of shape (..., dim//2)
        """
        assert (pos_idx.shape[:-1] == pos_idx_mask.shape
                and pos_idx.shape[-1] == 2 and pos_idx.ndim
                == pos_idx_mask.ndim + 1), (pos_idx.shape, pos_idx_mask.shape)
        assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype

        shp = pos_idx_mask.shape + (self.dim // 2, )  # ..., head_dim/2
        freqs_cis = torch.ones(shp, dtype=torch.complex64,
                               device=self.device)  # ..., head_dim/2
        freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[pos_idx[
            ..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask]]
        return freqs_cis


class MLP2(nn.Module):
    """
    Args:
        dims: [in_dim, hidden_dim, out_dim]
        bias: whether to use bias in linear layer.
    """

    def __init__(self, dims: list[int], activation, bias=True):
        super().__init__()
        assert len(dims) == 3
        self.fc0 = nn.Linear(dims[0], dims[1], bias=bias)
        self.fc1 = nn.Linear(dims[1], dims[2], bias=bias)
        self.activation = activation
        for m in [self.fc0, self.fc1]:
            nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features))
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc0(x)
        x = self.activation(x)
        return self.fc1(x)


class MoonVitEncoderLayer(nn.Module):

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        *,
        attn_implementation: str = "sdpa",
        activation=F.gelu,
        attn_bias: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
        self.attn_implementation = attn_implementation
        # use fa2 in vllm by default
        if is_flash_attn_2_available():
            self.attn_implementation = "flash_attention_2"

        self.norm0 = nn.LayerNorm(hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
        self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)

    def attention_qkvpacked(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rope_freqs_cis: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            x (torch.Tensor): (batch_size, seqlen, hidden_dim)
            cu_seqlens (torch.Tensor):
        """
        xqkv = self.wqkv(x)

        qkv_shape = xqkv.size()[:-1] + (
            3,
            self.num_heads,
            self.hidden_size_per_attention_head,
        )
        # xqkv: (batch_size, seqlen, 3, nheads, headdim)
        xqkv = xqkv.view(*qkv_shape)
        xq, xk, xv = torch.unbind(xqkv, dim=-3)

        xq, xk = apply_rope(xq, xk, rope_freqs_cis)

        attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
        attn_out = attn_func(xq,
                             xk,
                             xv,
                             q_cu_seqlens=cu_seqlens,
                             k_cu_seqlens=cu_seqlens)

        attn_out = self.wo(attn_out)
        return attn_out

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rope_freqs_cis: Union[torch.Tensor, None] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: non-packed (B, N, D) or packed (L, D). if non-packed, seqlens should be None, if packed, seqlens should be set

        Returns:
            output: same shape of input, non-packed (B, N, D) for non-packed input, (L, D) for packed input
        """
        residual = hidden_states
        hidden_states = self.norm0(hidden_states)
        attn_out = self.attention_qkvpacked(hidden_states,
                                            cu_seqlens,
                                            rope_freqs_cis=rope_freqs_cis)
        hidden_states = residual + attn_out

        residual = hidden_states
        hidden_states = self.mlp(self.norm1(hidden_states))
        hidden_states = residual + hidden_states
        return hidden_states


class MoonVitEncoder(nn.Module):

    def __init__(
        self,
        hidden_dim: int,
        num_layers: int,
        block_cfg: dict,
    ) -> None:
        super().__init__()

        self.rope_2d = Rope2DPosEmb(
            block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512)
        self.blocks = nn.ModuleList(
            [MoonVitEncoderLayer(**block_cfg) for _ in range(num_layers)])
        self.final_layernorm = nn.LayerNorm(hidden_dim)

    def forward(self, hidden_states: torch.Tensor,
                grid_hw: torch.Tensor) -> torch.Tensor:
        rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(
            grid_hws=grid_hw)

        lengths = torch.cat((
            torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
            grid_hw[:, 0] * grid_hw[:, 1],
        ))
        cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)

        for _, block in enumerate(self.blocks):
            hidden_states = block(hidden_states,
                                  cu_seqlens,
                                  rope_freqs_cis=rope_freqs_cis)

        hidden_states = self.final_layernorm(hidden_states)

        return hidden_states


def patch_merger(
    x: torch.Tensor,
    grid_hw: torch.Tensor,
    merge_kernel_size: list[int, int] = (2, 2),
) -> List[torch.Tensor]:
    d_model = x.size(-1)

    outputs = []
    pre_sum = 0
    for x_shape in grid_hw.tolist():
        height, width = x_shape[0], x_shape[1]
        # Get the current sequence
        seq = x[pre_sum:pre_sum + height * width]
        # Reshape along merge_kernel_size and concat to the last dimension
        kernel_height, kernel_width = merge_kernel_size
        new_height, new_width = height // kernel_height, width // kernel_width
        reshaped_seq = seq.view(new_height, kernel_height, new_width,
                                kernel_width, d_model)
        reshaped_seq = reshaped_seq.permute(0, 2, 1, 3, 4).contiguous()
        padded_seq = reshaped_seq.view(new_height * new_width,
                                       kernel_height * kernel_width, -1)
        outputs.append(padded_seq)
        pre_sum += height * width

    return outputs


class MoonVitVLProjector(nn.Module):

    def __init__(
        self,
        in_channels: int,
        merge_kernel_size: list[int, int],
        hidden_act: str = "gelu",
        ln_eps: float = 1e-5,
        out_dim: int = 4096,
    ):
        super().__init__()
        self.hidden_size = in_channels * merge_kernel_size[
            0] * merge_kernel_size[1]

        self.pre_norm = nn.LayerNorm(in_channels, eps=ln_eps)
        self.linear_1 = nn.Linear(self.hidden_size,
                                  self.hidden_size,
                                  bias=True)
        self.act = ACT2FN[hidden_act]
        self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size)
        hidden_states = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class MoonVitPretrainedModel(PreTrainedModel):
    config_class = MoonViTConfig
    model_type = "moonvit"
    _no_split_modules = ["PackingTransformer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        config = deepcopy(config)
        self.merge_kernel_size = config.merge_kernel_size
        self.patch_size = config.patch_size
        self.patch_embed = MoonVisionPatchEmbed(
            out_dim=config.hidden_size,
            patch_size=config.patch_size,
            pos_emb_height=config.init_pos_emb_height,
            pos_emb_width=config.init_pos_emb_width,
        )

        self.encoder = MoonVitEncoder(
            hidden_dim=config.hidden_size,
            num_layers=config.num_hidden_layers,
            block_cfg={
                "num_heads": config.num_attention_heads,
                "hidden_dim": config.hidden_size,
                "mlp_dim": config.intermediate_size,
                "activation": PytorchGELUTanh(),
                "attn_bias": True,
                "attn_implementation": config._attn_implementation,
            },
        )

    def forward(self, pixel_values: torch.Tensor,
                grid_hw: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pixel_values (torch.Tensor): The input pixel values.
            grid_hw (torch.Tensor): The grid height and width.

        Returns:
            torch.Tensor: The output tokens.
        """
        hidden_states = self.patch_embed(pixel_values, grid_hw)
        hidden_states = self.encoder(hidden_states, grid_hw)
        hidden_states = patch_merger(hidden_states,
                                     grid_hw,
                                     merge_kernel_size=self.merge_kernel_size)
        return hidden_states

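To make the shapes above concrete, here is a hedged, self-contained sketch (separate from the commit's code) that re-derives the `cu_seqlens` boundaries built in `MoonVitEncoder.forward` and the (2, 2) merge performed by `patch_merger` for two toy images; the sizes are arbitrary.

import torch

# Toy sizes (hypothetical): two images whose patch grids are 4x6 and 2x2.
grid_hw = torch.tensor([[4, 6], [2, 2]])
d_model = 8
num_tokens = int((grid_hw[:, 0] * grid_hw[:, 1]).sum())   # 24 + 4 = 28
x = torch.randn(num_tokens, d_model)                      # packed patch tokens

# Sequence boundaries, as in MoonVitEncoder.forward above.
lengths = torch.cat((torch.zeros(1, dtype=grid_hw.dtype),
                     grid_hw[:, 0] * grid_hw[:, 1]))
cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)
print(cu_seqlens.tolist())                                 # [0, 24, 28]

# The (2, 2) merge: every image becomes (H//2 * W//2) groups of 4 neighbouring patches.
kh, kw = 2, 2
outputs, pre_sum = [], 0
for h, w in grid_hw.tolist():
    seq = x[pre_sum:pre_sum + h * w]
    merged = (seq.view(h // kh, kh, w // kw, kw, d_model)
              .permute(0, 2, 1, 3, 4)
              .reshape(h // kh * (w // kw), kh * kw, d_model))
    outputs.append(merged)
    pre_sum += h * w

print([tuple(o.shape) for o in outputs])                   # [(6, 4, 8), (1, 4, 8)]
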
@ -177,6 +177,7 @@ _MULTIMODAL_MODELS = {
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
    "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501
@ -33,12 +33,13 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
                                             EAGLEConfig, ExaoneConfig,
                                             H2OVLChatConfig,
                                             InternVLChatConfig, JAISConfig,
                                             KimiVLConfig, MedusaConfig,
                                             MllamaConfig, MLPSpeculatorConfig,
                                             MPTConfig, NemotronConfig,
                                             NVLM_D_Config, Olmo2Config,
                                             RWConfig, SkyworkR1VChatConfig,
                                             SolarConfig, Telechat2Config,
                                             UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname
@ -62,6 +63,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    "cohere2": Cohere2Config,
    "dbrx": DbrxConfig,
    "deepseek_vl_v2": DeepseekVLV2Config,
    "kimi_vl": KimiVLConfig,
    "mpt": MPTConfig,
    "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct)
    "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)
@ -13,9 +13,11 @@ from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig
from vllm.transformers_utils.configs.internvl import InternVLChatConfig
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
from vllm.transformers_utils.configs.medusa import MedusaConfig
from vllm.transformers_utils.configs.mllama import MllamaConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.moonvit import MoonViTConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
@ -40,6 +42,8 @@ __all__ = [
    "ExaoneConfig",
    "MllamaConfig",
    "MLPSpeculatorConfig",
    "MoonViTConfig",
    "KimiVLConfig",
    "NemotronConfig",
    "NVLM_D_Config",
    "Olmo2Config",
vllm/transformers_utils/configs/kimi_vl.py (new file, 36 lines)
@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
from typing import Optional, Union

from transformers.configuration_utils import PretrainedConfig

from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
from vllm.transformers_utils.configs.moonvit import MoonViTConfig


class KimiVLConfig(PretrainedConfig):
    model_type = "kimi_vl"

    def __init__(self,
                 vision_config: Optional[Union[dict, MoonViTConfig]] = None,
                 text_config: Optional[Union[dict, DeepseekV2Config]] = None,
                 ignore_index: int = -100,
                 media_placeholder_token_id: int = 163605,
                 pad_token_id: int = 0,
                 **kwargs):
        if vision_config is None:
            vision_config = MoonViTConfig()
        elif isinstance(vision_config, dict):
            vision_config = MoonViTConfig(**vision_config)
        self.vision_config = vision_config

        if text_config is None:
            text_config = DeepseekV2Config()
        elif isinstance(text_config, dict):
            text_config = DeepseekV2Config(**text_config)
        self.text_config = text_config

        self.ignore_index = ignore_index
        self.media_placeholder_token_id = media_placeholder_token_id

        super().__init__(pad_token_id=pad_token_id, **kwargs)
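A minimal usage sketch for the new config class, assuming a vLLM build that contains this commit; the `vision_config` values below are arbitrary examples, not the released checkpoint's numbers. It shows that dict sub-configs are promoted to their typed config classes.

from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig

cfg = KimiVLConfig(vision_config={"patch_size": 14, "hidden_size": 1152})

print(type(cfg.vision_config).__name__)   # MoonViTConfig
print(type(cfg.text_config).__name__)     # DeepseekV2Config (built from defaults)
print(cfg.media_placeholder_token_id)     # 163605
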
vllm/transformers_utils/configs/moonvit.py (new file, 32 lines)
@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
from transformers.configuration_utils import PretrainedConfig


class MoonViTConfig(PretrainedConfig):
    model_type = "moonvit"

    def __init__(
        self,
        patch_size: int = 14,
        init_pos_emb_height: int = 64,
        init_pos_emb_width: int = 64,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 27,
        hidden_size: int = 1152,
        intermediate_size: int = 4304,
        merge_kernel_size: tuple[int, int] = (2, 2),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        # Positional embedding config
        self.init_pos_emb_height = init_pos_emb_height
        self.init_pos_emb_width = init_pos_emb_width
        # Transformer config
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        # Patch merger config
        self.merge_kernel_size = merge_kernel_size
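As a rough worked example of what these defaults imply for the number of vision tokens per image (actual preprocessing may resize, pad, or crop first, so treat this purely as the arithmetic of `patch_size` and `merge_kernel_size`):

def vision_tokens(height: int, width: int, patch_size: int = 14,
                  merge_kernel_size: tuple[int, int] = (2, 2)) -> int:
    # Patch grid, then shrink by the merge kernel.
    grid_h, grid_w = height // patch_size, width // patch_size
    kh, kw = merge_kernel_size
    return (grid_h // kh) * (grid_w // kw)


print(vision_tokens(896, 896))  # 64x64 patch grid -> 32x32 after merging -> 1024 tokens
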
@ -997,13 +997,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # Return empty ModelRunnerOutput if there's no work to do.
             return EMPTY_MODEL_RUNNER_OUTPUT
 
-        if self.is_multimodal_model:
-            # Run the multimodal encoder if any.
-            self._execute_mm_encoder(scheduler_output)
-            mm_embeds = self._gather_mm_embeddings(scheduler_output)
-        else:
-            mm_embeds = []
-
         # Prepare the decoder inputs.
         attn_metadata, logits_indices, spec_decode_metadata = (
             self._prepare_inputs(scheduler_output))
@ -1019,6 +1012,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             num_input_tokens = num_scheduled_tokens
         attn_metadata.num_input_tokens = num_input_tokens
 
+        # _prepare_inputs may reorder the batch, so we must gather multi
+        # modal outputs after that to ensure the correct order
+        if self.is_multimodal_model:
+            # Run the multimodal encoder if any.
+            self._execute_mm_encoder(scheduler_output)
+            mm_embeds = self._gather_mm_embeddings(scheduler_output)
+        else:
+            mm_embeds = []
+
         if self.is_multimodal_model:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
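The move is explained by the new comment: gathering multimodal embeddings before `_prepare_inputs` could pair them with the wrong requests once the batch is reordered. A toy sketch of that failure mode (plain Python lists, not vLLM APIs):

# Per-request embeddings gathered in the original batch order.
requests = ["req_a", "req_b", "req_c"]
embeds = [f"emb({r})" for r in requests]

# A reordering such as _prepare_inputs might apply.
perm = [2, 0, 1]
reordered = [requests[i] for i in perm]

print(list(zip(reordered, embeds)))                     # wrong: req_c paired with emb(req_a)
print(list(zip(reordered, [embeds[i] for i in perm])))  # gather after reordering: correct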