From b1308b84a3a6323f33dc995923b6cffa86e7199e Mon Sep 17 00:00:00 2001 From: courage17340 Date: Tue, 15 Apr 2025 05:41:48 +0800 Subject: [PATCH] [Model][VLM] Add Kimi-VL model support (#16387) Signed-off-by: courage17340 --- docs/source/models/supported_models.md | 7 + examples/offline_inference/vision_language.py | 24 + .../vision_language_multi_image.py | 40 ++ requirements/test.in | 1 + requirements/test.txt | 10 +- .../vision_language/test_models.py | 12 + .../vision_language/vlm_utils/model_utils.py | 11 + .../multimodal/processing/test_common.py | 1 + tests/models/registry.py | 3 + vllm/entrypoints/chat_utils.py | 2 + vllm/model_executor/models/kimi_vl.py | 608 +++++++++++++++++ vllm/model_executor/models/moonvit.py | 628 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/transformers_utils/config.py | 14 +- vllm/transformers_utils/configs/__init__.py | 4 + vllm/transformers_utils/configs/kimi_vl.py | 36 + vllm/transformers_utils/configs/moonvit.py | 32 + vllm/v1/worker/gpu_model_runner.py | 16 +- 18 files changed, 1436 insertions(+), 14 deletions(-) create mode 100644 vllm/model_executor/models/kimi_vl.py create mode 100644 vllm/model_executor/models/moonvit.py create mode 100644 vllm/transformers_utils/configs/kimi_vl.py create mode 100644 vllm/transformers_utils/configs/moonvit.py diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index ffedd5b0..b6fef2f4 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -886,6 +886,13 @@ See [this page](#generative-models) for more information on how to use generativ * * ✅︎ * ✅︎ +- * `KimiVLForConditionalGeneration` + * Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking + * T + I+ + * `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` + * + * + * ✅︎ - * `Llama4ForConditionalGeneration` * Llama 4 * T + I+ diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index f51cef95..281d4fbd 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -364,6 +364,29 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData: ) +# Kimi-VL +def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + prompts = [ + "<|im_user|>user<|im_middle|><|media_start|>image<|media_content|>" + f"<|media_pad|><|media_end|>{question}<|im_end|>" + "<|im_assistant|>assistant<|im_middle|>" for question in questions + ] + + engine_args = EngineArgs( + model="moonshotai/Kimi-VL-A3B-Instruct", + max_model_len=4096, + disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + trust_remote_code=True, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # LLaVA-1.5 def run_llava(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -966,6 +989,7 @@ model_example_map = { "h2ovl_chat": run_h2ovl, "idefics3": run_idefics3, "internvl_chat": run_internvl, + "kimi_vl": run_kimi_vl, "llava": run_llava, "llava-next": run_llava_next, "llava-next-video": run_llava_next_video, diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 89818f8b..6fa4a754 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -326,6 +326,45 @@ def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: ) +def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "moonshotai/Kimi-VL-A3B-Instruct" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=4, + tensor_parallel_size=1, + limit_mm_per_prompt={"image": len(image_urls)}, + trust_remote_code=True, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [{ + "role": + "user", + "content": [ + *placeholders, + { + "type": "text", + "text": question + }, + ], + }] + + processor = AutoProcessor.from_pretrained(model_name, + trust_remote_code=True) + + prompt = processor.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" @@ -640,6 +679,7 @@ model_example_map = { "h2ovl_chat": load_h2ovl, "idefics3": load_idefics3, "internvl_chat": load_internvl, + "kimi_vl": load_kimi_vl, "llama4": load_llama4, "mistral3": load_mistral3, "mllama": load_mllama, diff --git a/requirements/test.in b/requirements/test.in index 95c94dcd..b9b3df06 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -10,6 +10,7 @@ pytest-timeout # testing utils awscli backoff # required for phi4mm test +blobfile # required for kimi-vl test einops # required for MPT, qwen-vl and Mamba httpx librosa # required for audio tests diff --git a/requirements/test.txt b/requirements/test.txt index 476b4a2c..a5c062b0 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -39,6 +39,8 @@ bitsandbytes==0.45.3 # via -r requirements/test.in black==24.10.0 # via datamodel-code-generator +blobfile==3.0.0 + # via -r requirements/test.in boto3==1.35.57 # via tensorizer botocore==1.35.57 @@ -127,6 +129,7 @@ fastsafetensors==0.1.10 # via -r requirements/test.in filelock==3.16.1 # via + # blobfile # datasets # huggingface-hub # ray @@ -227,7 +230,9 @@ llvmlite==0.44.0 lm-eval==0.4.8 # via -r requirements/test.in lxml==5.3.0 - # via sacrebleu + # via + # blobfile + # sacrebleu markdown-it-py==3.0.0 # via rich markupsafe==3.0.2 @@ -426,6 +431,8 @@ pybind11==2.13.6 # via lm-eval pycparser==2.22 # via cffi +pycryptodomex==3.22.0 + # via blobfile pydantic==2.9.2 # via # datamodel-code-generator @@ -689,6 +696,7 @@ tzdata==2024.2 # via pandas urllib3==2.2.3 # via + # blobfile # botocore # requests # responses diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 5bd10544..5c87cefc 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -318,6 +318,18 @@ VLM_TEST_SETTINGS = { use_tokenizer_eos=True, patch_hf_runner=model_utils.internvl_patch_hf_runner, ), + "kimi_vl": VLMTestInfo( + models=["moonshotai/Kimi-VL-A3B-Instruct"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|im_user|>user<|im_middle|>{img_prompt}<|im_end|><|im_assistant|>assistant<|im_middle|>", # noqa: E501 + img_idx_to_prompt=lambda _: "<|media_start|>image<|media_content|><|media_pad|><|media_end|>", # noqa: E501 + max_model_len=8192, + max_num_seqs=2, + dtype="bfloat16", + tensor_parallel_size=1, + vllm_output_post_proc=model_utils.kimiv_vl_vllm_to_hf_output, + marks=[large_gpu_mark(min_gb=48)], + ), "llama4": VLMTestInfo( models=["meta-llama/Llama-4-Scout-17B-16E-Instruct"], prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n", # noqa: E501 diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index 3520345c..49305332 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -68,6 +68,17 @@ def qwen2_vllm_to_hf_output( return output_ids, hf_output_str, out_logprobs +def kimiv_vl_vllm_to_hf_output( + vllm_output: RunnerOutput, + model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]: + """Sanitize vllm output [kimi_vl models] to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + hf_output_str = output_str + "<|im_end|>[EOS]" + + return output_ids, hf_output_str, out_logprobs + + def llava_image_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: config = AutoConfig.from_pretrained(model) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 0f25c189..b14e8a02 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -258,6 +258,7 @@ def _test_processing_correctness_mistral( "OpenGVLab/InternVL2-1B", "HuggingFaceM4/Idefics3-8B-Llama3", "HuggingFaceTB/SmolVLM2-2.2B-Instruct", + "moonshotai/Kimi-VL-A3B-Instruct", "meta-llama/Llama-4-Scout-17B-16E-Instruct", "llava-hf/llava-1.5-7b-hf", "llava-hf/llava-v1.6-mistral-7b-hf", diff --git a/tests/models/registry.py b/tests/models/registry.py index 896b6c3b..530da89c 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -302,6 +302,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { trust_remote_code=True), "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501 {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501 + "KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501 + extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501 + trust_remote_code=True), "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501 min_transformers_version="4.51"), "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf", diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 6fb7dc2c..d6010e1c 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -512,6 +512,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): return "<|fim_prefix|><|img|><|fim_suffix|>" if model_type == "gemma3": return "" + if model_type == "kimi_vl": + return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>" # noqa: E501 raise TypeError(f"Unknown {modality} model type: {model_type}") elif modality == "audio": diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py new file mode 100644 index 00000000..c2fac70a --- /dev/null +++ b/vllm/model_executor/models/kimi_vl.py @@ -0,0 +1,608 @@ +# SPDX-License-Identifier: Apache-2.0 +# ruff: noqa: E501 +# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py +# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved. +# +# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL. +# +# Licensing Information: +# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0. +# - Other parts of the code are licensed under the MIT License. +# +# Apache License, Version 2.0: +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License: +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import copy +import math +from collections.abc import Mapping +from dataclasses import dataclass +from typing import (Any, Iterable, List, Literal, Optional, Sequence, Tuple, + TypedDict, Union) + +import torch +from torch import nn +from transformers import BatchFeature +from transformers.activations import GELUActivation + +from vllm.config import VllmConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.moonvit import MoonVitPretrainedModel +from vllm.model_executor.models.utils import merge_multimodal_embeddings +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, + MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig +from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config + +from .utils import is_pp_missing_parameter, maybe_prefix + +logger = init_logger(__name__) + + +# For dummy input only +@dataclass +class MaxImageTokenMeta: + width: int = 1024 + height: int = 1024 + + +class KimiVLMultiModalProjector(nn.Module): + + def __init__(self, config: KimiVLConfig): + super().__init__() + + self.hidden_size = (config.vision_config.hidden_size * + config.vision_config.merge_kernel_size[0] * + config.vision_config.merge_kernel_size[1]) + + self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, + eps=1e-5) + self.linear_1 = nn.Linear(self.hidden_size, + self.hidden_size, + bias=True) + self.act = GELUActivation() + self.linear_2 = nn.Linear(self.hidden_size, + config.text_config.hidden_size, + bias=True) + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + hidden_states = self.pre_norm(image_features).view( + -1, self.hidden_size) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class KimiVLImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: Union[torch.Tensor, List[torch.Tensor]] + """ + Shape:`(num_patches, num_channels, patch_size, patch_size)` + """ + + image_grid_hws: torch.Tensor + """Shape:`(num_images, 2)`""" + + +# TODO: support embeds too +# We only support pixel input for kimi-vl now +KimiVLImageInputs = KimiVLImagePixelInputs + + +class KimiVLProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(KimiVLConfig) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_processor = self.get_hf_processor() + patch_size = hf_processor.image_processor.patch_size + kernel_size = hf_processor.image_processor.merge_kernel_size + in_token_limit = hf_processor.image_processor.in_token_limit + height = image_height + width = image_width + assert isinstance(height, + int), f"height must be int, current height {height}" + assert isinstance(width, + int), f"width must be int, current width {width}" + assert kernel_size is not None, "kernel_size must be specified" + + if (width // patch_size) * (height // patch_size) > in_token_limit: + scale = math.sqrt(in_token_limit / ((width // patch_size) * + (height // patch_size))) + new_w, new_h = int(width * scale), int(height * scale) + width, height = new_w, new_h + + kernel_height, kernel_width = kernel_size + + pad_height = (kernel_height * patch_size - height % + (kernel_height * patch_size)) % (kernel_height * + patch_size) + pad_width = (kernel_width * patch_size - width % + (kernel_width * patch_size)) % (kernel_width * patch_size) + + # Calculate new dimensions after padding and patching + token_height = (height + pad_height) // (kernel_size[0] * patch_size) + token_width = (width + pad_width) // (kernel_size[1] * patch_size) + return int(token_height * token_width) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + # None means unlimited + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return { + "image": + self.get_num_image_tokens( + image_width=MaxImageTokenMeta.width, + image_height=MaxImageTokenMeta.height, + ), + } + + @property + def image_token_id(self) -> int: + return self.get_hf_config().media_placeholder_token_id + + +class KimiVLDummyInputsBuilder(BaseDummyInputsBuilder[KimiVLProcessingInfo]): + + def __init__(self, info: KimiVLProcessingInfo) -> None: + super().__init__(info) + + self.image_token_id = self.info.image_token_id + self.image_token = self.info.get_tokenizer().decode( + self.image_token_id) + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + width = MaxImageTokenMeta.width + height = MaxImageTokenMeta.height + mm_data = { + "image": + self._get_dummy_images(width=width, + height=height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text=self.image_token * num_images, + mm_data=mm_data, + ) + + +class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]): + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + image_grid_hws = hf_inputs.get("image_grid_hws", torch.empty((0, 2))) + image_grid_sizes = image_grid_hws.prod(-1) + + # pixel_values is merged as a single large tensor + # image_grid_hws is shapes for each subtensor in pixel_values + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_grid_hws=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + image_token_id = self.info.image_token_id + + def get_replacement(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + ) + + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement, + ), + ] + + +@MULTIMODAL_REGISTRY.register_processor(KimiVLMultiModalProcessor, + info=KimiVLProcessingInfo, + dummy_inputs=KimiVLDummyInputsBuilder) +class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: + super().__init__() + model_config = vllm_config.model_config + config: KimiVLConfig = model_config.hf_config + self.config = config + quant_config = vllm_config.quant_config + + assert isinstance(config.vision_config, MoonViTConfig) + + self.vision_tower = MoonVitPretrainedModel(config.vision_config) + + self.multi_modal_projector = KimiVLMultiModalProjector(config=config) + + self.quant_config = quant_config + sub_vllm_config = copy.deepcopy(vllm_config) + sub_vllm_config.model_config.hf_config = sub_vllm_config.model_config.hf_config.text_config + self.language_model = DeepseekV2Model( + vllm_config=sub_vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + self.unpadded_vocab_size = config.text_config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.text_config.hidden_size, + org_num_embeddings=self.config.text_config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) + self.sampler = get_sampler() + self.media_placeholder: int = self.config.media_placeholder_token_id + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_world_size = get_tensor_model_parallel_world_size() + + # ref: qwen2_vl.py + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") + return mm_input.reshape(-1, mm_input.shape[-1]) + else: + return torch.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[KimiVLImageInputs]: + # image input type must be pixel values now + pixel_values = kwargs.pop("pixel_values", None) + image_grid_hws = kwargs.pop("image_grid_hws", None) + + if pixel_values is None: + return None + + image_grid_hws = self._validate_and_reshape_mm_tensor( + image_grid_hws, "image grid hws") + # pixel_values may have complex shapes + num_channels = 3 + patch_size = self.config.vision_config.patch_size + if isinstance(pixel_values, list): + pixel_values = torch.cat([ + x.reshape(-1, num_channels, patch_size, patch_size) + for x in pixel_values + ]) + else: + pixel_values = pixel_values.reshape(-1, num_channels, patch_size, + patch_size) + # fp32 -> bf16 + pixel_values = pixel_values.to(torch.bfloat16) + # image_grid_hws.shape = (N, 2) + assert image_grid_hws.ndim == 2, f"unexpected shape for image_grid_hws: {image_grid_hws.shape}" + + return KimiVLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_hws=image_grid_hws, + ) + + # perform vt on processored pixel_values + @torch.inference_mode() + def _process_image_pixels(self, + inputs: KimiVLImagePixelInputs) -> torch.Tensor: + assert self.vision_tower is not None + + pixel_values = inputs["pixel_values"] + image_grid_hws = inputs["image_grid_hws"] + return self.vision_tower(pixel_values, image_grid_hws) + + def _process_image_input(self, + image_input: KimiVLImageInputs) -> torch.Tensor: + assert image_input["type"] == "pixel_values" + image_features = self._process_image_pixels(image_input) + assert isinstance(image_features, list) + lengths = [x.shape[0] for x in image_features] + return self.multi_modal_projector( + torch.cat(image_features)).split(lengths) + + def get_multimodal_embeddings(self, + **kwargs: object) -> Optional[NestedTensors]: + # Validate the multimodal input keyword arguments + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + # Run multimodal inputs through encoder and projector + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + + # `get_input_embeddings` should already be implemented for the language + # model as one of the requirements of basic vLLM model implementation. + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + placeholder_token_id=self.config.media_placeholder_token_id) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> SamplerOutput: + if intermediate_tensors is not None: + inputs_embeds = None + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. + elif inputs_embeds is None: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + inputs_embeds = None + else: + inputs_embeds = self.get_input_embeddings(input_ids) + image_embeds = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=self.config. + media_placeholder_token_id, + ) + input_ids = None + + hidden_states = self.language_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + **kwargs) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata, **kwargs) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + config = self.config.text_config + _KEYS_TO_MODIFY_MAPPING = { + "language_model.lm_head": "lm_head", + "language_model.model": "language_model", + } + # only doing this for language model part for now. + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + if not config.use_mla: + stacked_params_mapping += [ + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ] + if getattr(config, "n_routed_experts", None): + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=config.n_routed_experts) + else: + expert_params_mapping = [] + + params_dict = dict(self.named_parameters()) + for args in weights: + name, loaded_weight = args[:2] + kwargs = args[2] if len(args) > 2 else {} + if "rotary_emb.inv_freq" in name: + continue + + spec_layer = get_spec_layer_idx_from_weight_name(config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model + + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in name: + name = name.replace(key_to_modify, new_key) + use_default_weight_loading = False + if "vision" in name: + if self.vision_tower is not None: + # We only do sharding for language model and + # not vision model for now. + use_default_weight_loading = True + else: + for (param_name, weight_name, + shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id, **kwargs) + break + else: + for idx, (param_name, weight_name, expert_id, + shard_id) in enumerate(expert_params_mapping): + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + expert_id=expert_id, + shard_id=shard_id, + **kwargs) + break + else: + use_default_weight_loading = True + if use_default_weight_loading: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight, **kwargs) + + +def get_spec_layer_idx_from_weight_name(config: DeepseekV2Config, + weight_name: str) -> Optional[int]: + if hasattr(config, + "num_nextn_predict_layers") and (config.num_nextn_predict_layers + > 0): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if weight_name.startswith(f"model.layers.{layer_idx+i}."): + return layer_idx + i + return None diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py new file mode 100644 index 00000000..c367d90f --- /dev/null +++ b/vllm/model_executor/models/moonvit.py @@ -0,0 +1,628 @@ +# SPDX-License-Identifier: Apache-2.0 +# ruff: noqa: E501 +# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py +# This file is meant to be used in kimi_vl.py only +# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved. +# +# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL. +# +# Licensing Information: +# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0. +# - Other parts of the code are licensed under the MIT License. +# +# Apache License, Version 2.0: +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License: +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import math +from copy import deepcopy +from functools import cached_property +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.activations import ACT2FN, PytorchGELUTanh +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import is_flash_attn_2_available + +from vllm.transformers_utils.configs.moonvit import MoonViTConfig + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func +else: + flash_attn_varlen_func = None + + +def multihead_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q_cu_seqlens: Optional[torch.Tensor] = None, + k_cu_seqlens: Optional[torch.Tensor] = None, +): + """Multi-head attention using flash attention 2. + + Args: + q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim), + or (tot_seqlens, num_heads, head_dim) if packing. + q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q. + The first element should be 0 and the last element should be q.shape[0]. + k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k. + The first element should be 0 and the last element should be k.shape[0]. + + Returns: + output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing, + where dim = num_heads * head_dim + """ + # Unified format legal check + assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims" + assert q_cu_seqlens[-1] == q.shape[ + 0], "q_cu_seqlens must sum to q.shape[0]" + assert (k_cu_seqlens[-1] == k.shape[0] == + v.shape[0]), "k_cu_seqlens must sum to k.shape[0]" + assert q.dtype in [ + torch.bfloat16, + torch.float16, + ], f"unsupported dtype {q.dtype} for multihead attn" + + max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item() + max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item() + attn_out = flash_attn_varlen_func( + q, + k, + v, + q_cu_seqlens, + k_cu_seqlens, + max_seqlen_q, + max_seqlen_k, + causal=False, + ) + attn_out = attn_out.flatten(start_dim=-2) + + return attn_out + + +def sdpa_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q_cu_seqlens: Optional[torch.Tensor] = None, + k_cu_seqlens: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """SDPA attention. + + Args: + q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim), + or (tot_seqlens, num_heads, head_dim) if packing. + """ + seq_length = q.shape[0] + attention_mask = torch.zeros([1, seq_length, seq_length], + device=q.device, + dtype=torch.bool) + for i in range(1, len(q_cu_seqlens)): + attention_mask[ + ..., + q_cu_seqlens[i - 1]:q_cu_seqlens[i], + q_cu_seqlens[i - 1]:q_cu_seqlens[i], + ] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_output = F.scaled_dot_product_attention(q, + k, + v, + attention_mask, + dropout_p=0.0) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + return attn_output + + +VL_VISION_ATTENTION_FUNCTIONS = { + "flash_attention_2": multihead_attention, + "sdpa": sdpa_attention, +} + + +def _apply_rope_input_validation(x, freqs_cis): + assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape) + assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape) + assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape) + assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype + + +def apply_rope(xq: torch.Tensor, xk: torch.Tensor, + freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: (The leading dimensions of all inputs should be the same) + xq: query, tensor of shape (..., num_heads, head_dim) + xk: key, tensor of shape (..., num_heads, head_dim) + freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid. + Returns: + xq_out, xk_out: tensors of shape (..., num_heads, head_dim) + """ + _apply_rope_input_validation(xq, freqs_cis) + _apply_rope_input_validation(xk, freqs_cis) + + freqs_cis = freqs_cis.unsqueeze(-2) # ..., 1, head_dim/2 + # ..., num_heads, head_dim/2 + xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().view(*xq.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten( + -2) # ..., num_heads, head_dim + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten( + -2) # ..., num_heads, head_dim + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class Learnable2DInterpPosEmb(nn.Module): + + def __init__(self, + height: int, + width: int, + dim: int, + interpolation_mode: str = "bicubic") -> None: + super().__init__() + self.height = height + self.width = width + self.interpolation_mode = interpolation_mode + self.weight = nn.Parameter(torch.empty(height, width, dim)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.normal_(self.weight) + + def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor: + pos_embs = [] + for shape in grid_hws.tolist(): + if shape == self.weight.shape[:-1]: + pos_embs.append(self.weight.flatten(end_dim=1)) + else: + pos_embs.append( + F.interpolate( + self.weight.permute((2, 0, 1)).unsqueeze(0), + size=shape, + mode=self.interpolation_mode, + ).squeeze(0).permute((1, 2, 0)).flatten(end_dim=1)) + out = x + torch.cat(pos_embs) + return out + + +class MoonVisionPatchEmbed(nn.Module): + + def __init__( + self, + out_dim: int, + in_dim: int = 3, + patch_size: Union[int, Tuple[int, int]] = (14, 14), + pos_emb_height: int = 14, + pos_emb_width: int = 14, + ): + super().__init__() + assert isinstance( + patch_size, + (int, Sequence)), f"Invalid patch_size type: {type(patch_size)}" + if isinstance(patch_size, int): + patch_size = (patch_size, patch_size) + assert (len(patch_size) == 2 + ), f"Expected patch_size to be a tuple of 2, got {patch_size}" + self.patch_size = patch_size + + self.proj = nn.Conv2d(in_dim, + out_dim, + kernel_size=patch_size, + stride=patch_size) + + self.pos_emb = Learnable2DInterpPosEmb(height=pos_emb_height, + width=pos_emb_width, + dim=out_dim) + + def forward(self, x: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor: + """ + Args: + x (L, Channels): input tensor + grid_hw (N, 2): grid height and width + + Returns: + (L, Cout) tensor + """ + x = self.proj(x).view(x.size(0), -1) + # apply positional embedding + x = self.pos_emb(x, grid_hw) + return x + + +class Rope2DPosEmb(nn.Module): + """2D rotary position embedding with multi-resolution support. + + This class is intended to be used in the following way: + 1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis. + 2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration. + 3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation. + The rope is shared across all attention layers and all heads. + + Refs: + - RoFormer: https://arxiv.org/abs/2104.09864 + - VisionLLaMA: https://arxiv.org/abs/2403.00522 + - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py + + Args: + dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed) + max_height (int): the maximum height of the 2D grid + max_width (int): the maximum width of the 2D grid + theta_base (float): the base of the theta + device (str): the device to store the precomputed cis + """ + + def __init__(self, + dim: int, + max_height: int, + max_width: int, + theta_base=10000, + device="cuda"): + super().__init__() + self.dim = dim + assert self.dim % 4 == 0, "dim must be divisible by 4" + self.max_height = max_height + self.max_width = max_width + self.theta_base = theta_base + self.device = device + + def extra_repr(self): + return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}" + + @cached_property + def precomputed_freqs_cis(self) -> torch.Tensor: + """Calculate the cis(freqs) for each position in the 2D grid. + + Return: complex tensor of shape (max_height, max_width, dim//2) and value: + height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim)) + weight axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)) with (i in [0, dim//4)) + note: `cis` is a mathematical notation defined by cis x = cos x + i sin x, + """ + N = self.max_height * self.max_width + flat_pos = torch.arange(0, N).float().to(self.device) + x_pos = flat_pos % self.max_width + y_pos = flat_pos // self.max_width + dim_range = (torch.arange(0, self.dim, + 4)[:(self.dim // 4)].float().to(self.device) + ) # C/4 + freqs = 1.0 / (self.theta_base**(dim_range / self.dim)) + x_freqs = torch.outer(x_pos, freqs).float() # N, C/4 + y_freqs = torch.outer(y_pos, freqs).float() # N, C/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) # N, C/4 + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) # N, C/4 + # N, C/4, 2 + freqs_cis = torch.cat( + [x_cis.unsqueeze(dim=-1), + y_cis.unsqueeze(dim=-1)], dim=-1) + # max_height, max_width, C/2 + freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1) + return freqs_cis + + def get_freqs_cis_by_seqlens(self, grid_hws: torch.Tensor) -> torch.Tensor: + """ + Args: + grid_hws (torch.Tensor): containing list of (height, width) or (t, height, width) tuples. + Returns: + freqs_cis: tensor of shape (sum(t * height * width), dim//2) + """ + shapes = grid_hws.tolist() + assert all(1 <= h <= self.max_height and 1 <= w <= self.max_width + for h, w in shapes), ( + shapes, + self.max_height, + self.max_width, + ) + freqs_cis = torch.cat( + [ + self.precomputed_freqs_cis[:h, :w].reshape(-1, self.dim // 2) + for h, w in shapes + ], + dim=0, + ) + return freqs_cis + + def get_freqs_cis_by_idx(self, pos_idx: torch.Tensor, + pos_idx_mask: torch.Tensor) -> torch.Tensor: + """ + Args: + pos_idx: tensor of shape (..., 2), It contains the (h, w) position indices of each 2D token. + pos_idx_mask: a mask of shape (...), the leading dimensions should be the same as pos_idx. + Rope will only be applied to the tokens with True mask. `freqs_cis` for the tokens with False mask with be ones. + Return: + freqs_cis: tensor of shape (..., dim//2) + """ + assert (pos_idx.shape[:-1] == pos_idx_mask.shape + and pos_idx.shape[-1] == 2 and pos_idx.ndim + == pos_idx_mask.ndim + 1), (pos_idx.shape, pos_idx_mask.shape) + assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype + + shp = pos_idx_mask.shape + (self.dim // 2, ) # ..., head_dim/2 + freqs_cis = torch.ones(shp, dtype=torch.complex64, + device=self.device) # ..., head_dim/2 + freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[pos_idx[ + ..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask]] + return freqs_cis + + +class MLP2(nn.Module): + """ + Args: + dims: [in_dim, hidden_dim, out_dim] + bias: whether to use bias in linear layer. + """ + + def __init__(self, dims: list[int], activation, bias=True): + super().__init__() + assert len(dims) == 3 + self.fc0 = nn.Linear(dims[0], dims[1], bias=bias) + self.fc1 = nn.Linear(dims[1], dims[2], bias=bias) + self.activation = activation + for m in [self.fc0, self.fc1]: + nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features)) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc0(x) + x = self.activation(x) + return self.fc1(x) + + +class MoonVitEncoderLayer(nn.Module): + + def __init__( + self, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + *, + attn_implementation: str = "sdpa", + activation=F.gelu, + attn_bias: bool = False, + ): + super().__init__() + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads + self.attn_implementation = attn_implementation + # use fa2 in vllm by default + if is_flash_attn_2_available(): + self.attn_implementation = "flash_attention_2" + + self.norm0 = nn.LayerNorm(hidden_dim) + self.norm1 = nn.LayerNorm(hidden_dim) + self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation) + self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias) + self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias) + + def attention_qkvpacked( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rope_freqs_cis: Optional[torch.Tensor] = None, + ): + """ + Args: + x (torch.Tensor): (batch_size, seqlen, hidden_dim) + cu_seqlens (torch.Tensor): + """ + xqkv = self.wqkv(x) + + qkv_shape = xqkv.size()[:-1] + ( + 3, + self.num_heads, + self.hidden_size_per_attention_head, + ) + # xqkv: (batch_size, seqlen, 3, nheads, headdim) + xqkv = xqkv.view(*qkv_shape) + xq, xk, xv = torch.unbind(xqkv, dim=-3) + + xq, xk = apply_rope(xq, xk, rope_freqs_cis) + + attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation] + attn_out = attn_func(xq, + xk, + xv, + q_cu_seqlens=cu_seqlens, + k_cu_seqlens=cu_seqlens) + + attn_out = self.wo(attn_out) + return attn_out + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rope_freqs_cis: Union[torch.Tensor, None] = None, + ) -> torch.Tensor: + """ + Args: + hidden_states: non-packed (B, N, D) or packed (L, D). if non-packed, seqlens should be None, if packed, seqlens should be set + + Returns: + output: same shape of input, non-packed (B, N, D) for non-packed input, (L, D) for packed input + """ + residual = hidden_states + hidden_states = self.norm0(hidden_states) + attn_out = self.attention_qkvpacked(hidden_states, + cu_seqlens, + rope_freqs_cis=rope_freqs_cis) + hidden_states = residual + attn_out + + residual = hidden_states + hidden_states = self.mlp(self.norm1(hidden_states)) + hidden_states = residual + hidden_states + return hidden_states + + +class MoonVitEncoder(nn.Module): + + def __init__( + self, + hidden_dim: int, + num_layers: int, + block_cfg: dict, + ) -> None: + super().__init__() + + self.rope_2d = Rope2DPosEmb( + block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512) + self.blocks = nn.ModuleList( + [MoonVitEncoderLayer(**block_cfg) for _ in range(num_layers)]) + self.final_layernorm = nn.LayerNorm(hidden_dim) + + def forward(self, hidden_states: torch.Tensor, + grid_hw: torch.Tensor) -> torch.Tensor: + rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens( + grid_hws=grid_hw) + + lengths = torch.cat(( + torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype), + grid_hw[:, 0] * grid_hw[:, 1], + )) + cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32) + + for _, block in enumerate(self.blocks): + hidden_states = block(hidden_states, + cu_seqlens, + rope_freqs_cis=rope_freqs_cis) + + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +def patch_merger( + x: torch.Tensor, + grid_hw: torch.Tensor, + merge_kernel_size: list[int, int] = (2, 2), +) -> List[torch.Tensor]: + d_model = x.size(-1) + + outputs = [] + pre_sum = 0 + for x_shape in grid_hw.tolist(): + height, width = x_shape[0], x_shape[1] + # Get the current sequence + seq = x[pre_sum:pre_sum + height * width] + # Reshape along self.merge_kernel_size and concat to the last dimension + kernel_height, kernel_width = merge_kernel_size + new_height, new_width = height // kernel_height, width // kernel_width + reshaped_seq = seq.view(new_height, kernel_height, new_width, + kernel_width, d_model) + reshaped_seq = reshaped_seq.permute(0, 2, 1, 3, 4).contiguous() + padded_seq = reshaped_seq.view(new_height * new_width, + kernel_height * kernel_width, -1) + outputs.append(padded_seq) + pre_sum += height * width + + return outputs + + +class MoonVitVLProjector(nn.Module): + + def __init__( + self, + in_channels: int, + merge_kernel_size: list[int, int], + hidden_act: str = "gelu", + ln_eps: float = 1e-5, + out_dim: int = 4096, + ): + super().__init__() + self.hidden_size = in_channels * merge_kernel_size[ + 0] * merge_kernel_size[1] + + self.pre_norm = nn.nn.LayerNorm(in_channels, eps=ln_eps) + self.linear_1 = nn.Linear(self.hidden_size, + self.hidden_size, + bias=True) + self.act = ACT2FN[hidden_act] + self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class MoonVitPretrainedModel(PreTrainedModel): + config_class = MoonViTConfig + model_type = "moonvit" + _no_split_modules = ["PackingTransformer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + + def __init__(self, config: MoonViTConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + config = deepcopy(config) + self.merge_kernel_size = config.merge_kernel_size + self.patch_size = config.patch_size + self.patch_embed = MoonVisionPatchEmbed( + out_dim=config.hidden_size, + patch_size=config.patch_size, + pos_emb_height=config.init_pos_emb_height, + pos_emb_width=config.init_pos_emb_width, + ) + + self.encoder = MoonVitEncoder( + hidden_dim=config.hidden_size, + num_layers=config.num_hidden_layers, + block_cfg={ + "num_heads": config.num_attention_heads, + "hidden_dim": config.hidden_size, + "mlp_dim": config.intermediate_size, + "activation": PytorchGELUTanh(), + "attn_bias": True, + "attn_implementation": config._attn_implementation, + }, + ) + + def forward(self, pixel_values: torch.Tensor, + grid_hw: torch.Tensor) -> torch.Tensor: + """ + Args: + pixel_values (torch.Tensor): The input pixel values. + grid_hw (torch.Tensor): The grid height and width. + + Returns: + torch.Tensor: The output tokens. + """ + hidden_states = self.patch_embed(pixel_values, grid_hw) + hidden_states = self.encoder(hidden_states, grid_hw) + hidden_states = patch_merger(hidden_states, + grid_hw, + merge_kernel_size=self.merge_kernel_size) + return hidden_states diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 0d13d699..b345113e 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -177,6 +177,7 @@ _MULTIMODAL_MODELS = { "InternVLChatModel": ("internvl", "InternVLChatModel"), "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"), "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501 + "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index fe0319c9..f37605be 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -33,12 +33,13 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config, EAGLEConfig, ExaoneConfig, H2OVLChatConfig, InternVLChatConfig, JAISConfig, - MedusaConfig, MllamaConfig, - MLPSpeculatorConfig, MPTConfig, - NemotronConfig, NVLM_D_Config, - Olmo2Config, RWConfig, - SkyworkR1VChatConfig, SolarConfig, - Telechat2Config, UltravoxConfig) + KimiVLConfig, MedusaConfig, + MllamaConfig, MLPSpeculatorConfig, + MPTConfig, NemotronConfig, + NVLM_D_Config, Olmo2Config, + RWConfig, SkyworkR1VChatConfig, + SolarConfig, Telechat2Config, + UltravoxConfig) # yapf: enable from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import resolve_obj_by_qualname @@ -62,6 +63,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { "cohere2": Cohere2Config, "dbrx": DbrxConfig, "deepseek_vl_v2": DeepseekVLV2Config, + "kimi_vl": KimiVLConfig, "mpt": MPTConfig, "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 53699341..739eea5c 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -13,9 +13,11 @@ from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig from vllm.transformers_utils.configs.internvl import InternVLChatConfig from vllm.transformers_utils.configs.jais import JAISConfig +from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig from vllm.transformers_utils.configs.medusa import MedusaConfig from vllm.transformers_utils.configs.mllama import MllamaConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig +from vllm.transformers_utils.configs.moonvit import MoonViTConfig from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config @@ -40,6 +42,8 @@ __all__ = [ "ExaoneConfig", "MllamaConfig", "MLPSpeculatorConfig", + "MoonViTConfig", + "KimiVLConfig", "NemotronConfig", "NVLM_D_Config", "Olmo2Config", diff --git a/vllm/transformers_utils/configs/kimi_vl.py b/vllm/transformers_utils/configs/kimi_vl.py new file mode 100644 index 00000000..97ff44bb --- /dev/null +++ b/vllm/transformers_utils/configs/kimi_vl.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py +from typing import Optional, Union + +from transformers.configuration_utils import PretrainedConfig + +from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config +from vllm.transformers_utils.configs.moonvit import MoonViTConfig + + +class KimiVLConfig(PretrainedConfig): + model_type = "kimi_vl" + + def __init__(self, + vision_config: Optional[Union[dict, MoonViTConfig]] = None, + text_config: Optional[Union[dict, DeepseekV2Config]] = None, + ignore_index: int = -100, + media_placeholder_token_id: int = 163605, + pad_token_id: int = 0, + **kwargs): + if vision_config is None: + vision_config = MoonViTConfig() + elif isinstance(vision_config, dict): + vision_config = MoonViTConfig(**vision_config) + self.vision_config = vision_config + + if text_config is None: + text_config = DeepseekV2Config() + elif isinstance(text_config, dict): + text_config = DeepseekV2Config(**text_config) + self.text_config = text_config + + self.ignore_index = ignore_index + self.media_placeholder_token_id = media_placeholder_token_id + + super().__init__(pad_token_id=pad_token_id, **kwargs) diff --git a/vllm/transformers_utils/configs/moonvit.py b/vllm/transformers_utils/configs/moonvit.py new file mode 100644 index 00000000..a2b4059a --- /dev/null +++ b/vllm/transformers_utils/configs/moonvit.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py +from transformers.configuration_utils import PretrainedConfig + + +class MoonViTConfig(PretrainedConfig): + model_type = "moonvit" + + def __init__( + self, + patch_size: int = 14, + init_pos_emb_height: int = 64, + init_pos_emb_width: int = 64, + num_attention_heads: int = 16, + num_hidden_layers: int = 27, + hidden_size: int = 1152, + intermediate_size: int = 4304, + merge_kernel_size: tuple[int, int] = (2, 2), + **kwargs, + ): + super().__init__(**kwargs) + self.patch_size = patch_size + # Positional embedding config + self.init_pos_emb_height = init_pos_emb_height + self.init_pos_emb_width = init_pos_emb_width + # Transformer config + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + # Patch merger config + self.merge_kernel_size = merge_kernel_size diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 70e8bd75..c3d84ab3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -997,13 +997,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Return empty ModelRunnerOutput if there's no work to do. return EMPTY_MODEL_RUNNER_OUTPUT - if self.is_multimodal_model: - # Run the multimodal encoder if any. - self._execute_mm_encoder(scheduler_output) - mm_embeds = self._gather_mm_embeddings(scheduler_output) - else: - mm_embeds = [] - # Prepare the decoder inputs. attn_metadata, logits_indices, spec_decode_metadata = ( self._prepare_inputs(scheduler_output)) @@ -1019,6 +1012,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): num_input_tokens = num_scheduled_tokens attn_metadata.num_input_tokens = num_input_tokens + # _prepare_inputs may reorder the batch, so we must gather multi + # modal outputs after that to ensure the correct order + if self.is_multimodal_model: + # Run the multimodal encoder if any. + self._execute_mm_encoder(scheduler_output) + mm_embeds = self._gather_mm_embeddings(scheduler_output) + else: + mm_embeds = [] + if self.is_multimodal_model: # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids)