diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index ffedd5b0..b6fef2f4 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -886,6 +886,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
+- * `KimiVLForConditionalGeneration`
+ * Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking
+ * T + I+
+ * `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking`
+ *
+ *
+ * ✅︎
- * `Llama4ForConditionalGeneration`
* Llama 4
* T + I+
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index f51cef95..281d4fbd 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -364,6 +364,29 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
)
+# Kimi-VL
+def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
+ assert modality == "image"
+
+ prompts = [
+ "<|im_user|>user<|im_middle|><|media_start|>image<|media_content|>"
+ f"<|media_pad|><|media_end|>{question}<|im_end|>"
+ "<|im_assistant|>assistant<|im_middle|>" for question in questions
+ ]
+
+ engine_args = EngineArgs(
+ model="moonshotai/Kimi-VL-A3B-Instruct",
+ max_model_len=4096,
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+ trust_remote_code=True,
+ )
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ )
+
+
# LLaVA-1.5
def run_llava(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -966,6 +989,7 @@ model_example_map = {
"h2ovl_chat": run_h2ovl,
"idefics3": run_idefics3,
"internvl_chat": run_internvl,
+ "kimi_vl": run_kimi_vl,
"llava": run_llava,
"llava-next": run_llava_next,
"llava-next-video": run_llava_next_video,
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index 89818f8b..6fa4a754 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -326,6 +326,45 @@ def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
)
+def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
+ model_name = "moonshotai/Kimi-VL-A3B-Instruct"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=4096,
+ max_num_seqs=4,
+ tensor_parallel_size=1,
+ limit_mm_per_prompt={"image": len(image_urls)},
+ trust_remote_code=True,
+ )
+
+ placeholders = [{"type": "image", "image": url} for url in image_urls]
+ messages = [{
+ "role":
+ "user",
+ "content": [
+ *placeholders,
+ {
+ "type": "text",
+ "text": question
+ },
+ ],
+ }]
+
+ processor = AutoProcessor.from_pretrained(model_name,
+ trust_remote_code=True)
+
+ prompt = processor.apply_chat_template(messages,
+ tokenize=False,
+ add_generation_prompt=True)
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompt=prompt,
+ image_data=[fetch_image(url) for url in image_urls],
+ )
+
+
def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
@@ -640,6 +679,7 @@ model_example_map = {
"h2ovl_chat": load_h2ovl,
"idefics3": load_idefics3,
"internvl_chat": load_internvl,
+ "kimi_vl": load_kimi_vl,
"llama4": load_llama4,
"mistral3": load_mistral3,
"mllama": load_mllama,
diff --git a/requirements/test.in b/requirements/test.in
index 95c94dcd..b9b3df06 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -10,6 +10,7 @@ pytest-timeout
# testing utils
awscli
backoff # required for phi4mm test
+blobfile # required for kimi-vl test
einops # required for MPT, qwen-vl and Mamba
httpx
librosa # required for audio tests
diff --git a/requirements/test.txt b/requirements/test.txt
index 476b4a2c..a5c062b0 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -39,6 +39,8 @@ bitsandbytes==0.45.3
# via -r requirements/test.in
black==24.10.0
# via datamodel-code-generator
+blobfile==3.0.0
+ # via -r requirements/test.in
boto3==1.35.57
# via tensorizer
botocore==1.35.57
@@ -127,6 +129,7 @@ fastsafetensors==0.1.10
# via -r requirements/test.in
filelock==3.16.1
# via
+ # blobfile
# datasets
# huggingface-hub
# ray
@@ -227,7 +230,9 @@ llvmlite==0.44.0
lm-eval==0.4.8
# via -r requirements/test.in
lxml==5.3.0
- # via sacrebleu
+ # via
+ # blobfile
+ # sacrebleu
markdown-it-py==3.0.0
# via rich
markupsafe==3.0.2
@@ -426,6 +431,8 @@ pybind11==2.13.6
# via lm-eval
pycparser==2.22
# via cffi
+pycryptodomex==3.22.0
+ # via blobfile
pydantic==2.9.2
# via
# datamodel-code-generator
@@ -689,6 +696,7 @@ tzdata==2024.2
# via pandas
urllib3==2.2.3
# via
+ # blobfile
# botocore
# requests
# responses
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index 5bd10544..5c87cefc 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -318,6 +318,18 @@ VLM_TEST_SETTINGS = {
use_tokenizer_eos=True,
patch_hf_runner=model_utils.internvl_patch_hf_runner,
),
+ "kimi_vl": VLMTestInfo(
+ models=["moonshotai/Kimi-VL-A3B-Instruct"],
+ test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+ prompt_formatter=lambda img_prompt: f"<|im_user|>user<|im_middle|>{img_prompt}<|im_end|><|im_assistant|>assistant<|im_middle|>", # noqa: E501
+ img_idx_to_prompt=lambda _: "<|media_start|>image<|media_content|><|media_pad|><|media_end|>", # noqa: E501
+ max_model_len=8192,
+ max_num_seqs=2,
+ dtype="bfloat16",
+ tensor_parallel_size=1,
+        vllm_output_post_proc=model_utils.kimi_vl_vllm_to_hf_output,
+ marks=[large_gpu_mark(min_gb=48)],
+ ),
"llama4": VLMTestInfo(
models=["meta-llama/Llama-4-Scout-17B-16E-Instruct"],
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n", # noqa: E501
diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
index 3520345c..49305332 100644
--- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
+++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -68,6 +68,17 @@ def qwen2_vllm_to_hf_output(
return output_ids, hf_output_str, out_logprobs
+def kimi_vl_vllm_to_hf_output(
+ vllm_output: RunnerOutput,
+ model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]:
+ """Sanitize vllm output [kimi_vl models] to be comparable with hf output."""
+ output_ids, output_str, out_logprobs = vllm_output
+
+ hf_output_str = output_str + "<|im_end|>[EOS]"
+
+ return output_ids, hf_output_str, out_logprobs
+
+
def llava_image_vllm_to_hf_output(vllm_output: RunnerOutput,
model: str) -> RunnerOutput:
config = AutoConfig.from_pretrained(model)
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 0f25c189..b14e8a02 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -258,6 +258,7 @@ def _test_processing_correctness_mistral(
"OpenGVLab/InternVL2-1B",
"HuggingFaceM4/Idefics3-8B-Llama3",
"HuggingFaceTB/SmolVLM2-2.2B-Instruct",
+ "moonshotai/Kimi-VL-A3B-Instruct",
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
"llava-hf/llava-1.5-7b-hf",
"llava-hf/llava-v1.6-mistral-7b-hf",
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 896b6c3b..530da89c 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -302,6 +302,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True),
"Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501
+ "KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501
+ extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501
+ trust_remote_code=True),
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
min_transformers_version="4.51"),
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 6fb7dc2c..d6010e1c 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -512,6 +512,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "<|fim_prefix|><|img|><|fim_suffix|>"
if model_type == "gemma3":
return ""
+ if model_type == "kimi_vl":
+ return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>" # noqa: E501
raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "audio":
diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py
new file mode 100644
index 00000000..c2fac70a
--- /dev/null
+++ b/vllm/model_executor/models/kimi_vl.py
@@ -0,0 +1,608 @@
+# SPDX-License-Identifier: Apache-2.0
+# ruff: noqa: E501
+# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
+# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
+#
+# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
+#
+# Licensing Information:
+# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
+# - Other parts of the code are licensed under the MIT License.
+#
+# Apache License, Version 2.0:
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# MIT License:
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import copy
+import math
+from collections.abc import Mapping
+from dataclasses import dataclass
+from typing import (Any, Iterable, List, Literal, Optional, Sequence, Tuple,
+ TypedDict, Union)
+
+import torch
+from torch import nn
+from transformers import BatchFeature
+from transformers.activations import GELUActivation
+
+from vllm.config import VllmConfig
+from vllm.distributed import (get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
+from vllm.model_executor.models.interfaces import SupportsMultiModal
+from vllm.model_executor.models.moonvit import MoonVitPretrainedModel
+from vllm.model_executor.models.utils import merge_multimodal_embeddings
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+ NestedTensors)
+from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
+ MultiModalDataItems)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ BaseProcessingInfo, PromptReplacement,
+ PromptUpdate)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
+from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
+
+from .utils import is_pp_missing_parameter, maybe_prefix
+
+logger = init_logger(__name__)
+
+
+# For dummy input only
+@dataclass
+class MaxImageTokenMeta:
+ width: int = 1024
+ height: int = 1024
+
+
+class KimiVLMultiModalProjector(nn.Module):
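+    """Projects merged MoonViT patch features to the language model hidden size.
+
+    The input size is the vision hidden size multiplied by the merge kernel
+    area (the patch groups produced by the vision tower's patch merger); the
+    output size is the text config's hidden size.
+    """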
+
+ def __init__(self, config: KimiVLConfig):
+ super().__init__()
+
+ self.hidden_size = (config.vision_config.hidden_size *
+ config.vision_config.merge_kernel_size[0] *
+ config.vision_config.merge_kernel_size[1])
+
+ self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size,
+ eps=1e-5)
+ self.linear_1 = nn.Linear(self.hidden_size,
+ self.hidden_size,
+ bias=True)
+ self.act = GELUActivation()
+ self.linear_2 = nn.Linear(self.hidden_size,
+ config.text_config.hidden_size,
+ bias=True)
+
+ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.pre_norm(image_features).view(
+ -1, self.hidden_size)
+ hidden_states = self.linear_1(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+class KimiVLImagePixelInputs(TypedDict):
+ type: Literal["pixel_values"]
+ pixel_values: Union[torch.Tensor, List[torch.Tensor]]
+ """
+    Shape: `(num_patches, num_channels, patch_size, patch_size)`
+ """
+
+ image_grid_hws: torch.Tensor
+    """Shape: `(num_images, 2)`"""
+
+
+# TODO: support embeds too
+# We only support pixel input for kimi-vl now
+KimiVLImageInputs = KimiVLImagePixelInputs
+
+
+class KimiVLProcessingInfo(BaseProcessingInfo):
+
+ def get_hf_config(self):
+ return self.ctx.get_hf_config(KimiVLConfig)
+
+ def get_num_image_tokens(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ ) -> int:
+ hf_processor = self.get_hf_processor()
+ patch_size = hf_processor.image_processor.patch_size
+ kernel_size = hf_processor.image_processor.merge_kernel_size
+ in_token_limit = hf_processor.image_processor.in_token_limit
+ height = image_height
+ width = image_width
+ assert isinstance(height,
+ int), f"height must be int, current height {height}"
+ assert isinstance(width,
+ int), f"width must be int, current width {width}"
+ assert kernel_size is not None, "kernel_size must be specified"
+
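+        # Illustrative example (patch_size=14, merge_kernel_size=(2, 2), and an
+        # assumed in_token_limit of 4096; check the processor config for the
+        # real value): a 1024x1024 image has (1024 // 14) ** 2 = 5329 patches,
+        # so it is rescaled by sqrt(4096 / 5329) to about 897x897, padded up to
+        # 924x924 (a multiple of 2 * 14 = 28), giving (924 // 28) ** 2 = 1089
+        # image tokens.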
+ if (width // patch_size) * (height // patch_size) > in_token_limit:
+ scale = math.sqrt(in_token_limit / ((width // patch_size) *
+ (height // patch_size)))
+ new_w, new_h = int(width * scale), int(height * scale)
+ width, height = new_w, new_h
+
+ kernel_height, kernel_width = kernel_size
+
+ pad_height = (kernel_height * patch_size - height %
+ (kernel_height * patch_size)) % (kernel_height *
+ patch_size)
+ pad_width = (kernel_width * patch_size - width %
+ (kernel_width * patch_size)) % (kernel_width * patch_size)
+
+ # Calculate new dimensions after padding and patching
+        token_height = (height + pad_height) // (kernel_height * patch_size)
+        token_width = (width + pad_width) // (kernel_width * patch_size)
+        return int(token_height * token_width)
+
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ # None means unlimited
+ return {"image": None}
+
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ return {
+ "image":
+ self.get_num_image_tokens(
+ image_width=MaxImageTokenMeta.width,
+ image_height=MaxImageTokenMeta.height,
+ ),
+ }
+
+ @property
+ def image_token_id(self) -> int:
+ return self.get_hf_config().media_placeholder_token_id
+
+
+class KimiVLDummyInputsBuilder(BaseDummyInputsBuilder[KimiVLProcessingInfo]):
+
+ def __init__(self, info: KimiVLProcessingInfo) -> None:
+ super().__init__(info)
+
+ self.image_token_id = self.info.image_token_id
+ self.image_token = self.info.get_tokenizer().decode(
+ self.image_token_id)
+
+ def get_dummy_processor_inputs(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> ProcessorInputs:
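+        # Dummy images reuse the 1024x1024 size from MaxImageTokenMeta so that
+        # profiling matches the per-image token budget reported by
+        # get_mm_max_tokens_per_item.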
+ num_images = mm_counts.get("image", 0)
+
+ width = MaxImageTokenMeta.width
+ height = MaxImageTokenMeta.height
+ mm_data = {
+ "image":
+ self._get_dummy_images(width=width,
+ height=height,
+ num_images=num_images)
+ }
+
+ return ProcessorInputs(
+ prompt_text=self.image_token * num_images,
+ mm_data=mm_data,
+ )
+
+
+class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ image_grid_hws = hf_inputs.get("image_grid_hws", torch.empty((0, 2)))
+ image_grid_sizes = image_grid_hws.prod(-1)
+
+        # pixel_values is merged into a single large tensor;
+        # image_grid_hws holds the (height, width) grid of each sub-tensor in it
+ return dict(
+ pixel_values=MultiModalFieldConfig.flat_from_sizes(
+ "image", image_grid_sizes),
+ image_grid_hws=MultiModalFieldConfig.batched("image"),
+ )
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> Sequence[PromptUpdate]:
+ image_token_id = self.info.image_token_id
+
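+        # Each image's single <|media_pad|> placeholder in the prompt is
+        # expanded to num_image_tokens copies of that token so the projected
+        # vision features can later be merged 1:1 into those positions.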
+ def get_replacement(item_idx: int):
+ images = mm_items.get_items(
+ "image", (ImageEmbeddingItems, ImageProcessorItems))
+
+ if isinstance(images, ImageEmbeddingItems):
+ num_image_tokens = images.get_feature_size(item_idx)
+ else:
+ image_size = images.get_image_size(item_idx)
+ num_image_tokens = self.info.get_num_image_tokens(
+ image_width=image_size.width,
+ image_height=image_size.height,
+ )
+
+ return [image_token_id] * num_image_tokens
+
+ return [
+ PromptReplacement(
+ modality="image",
+ target=[image_token_id],
+ replacement=get_replacement,
+ ),
+ ]
+
+
+@MULTIMODAL_REGISTRY.register_processor(KimiVLMultiModalProcessor,
+ info=KimiVLProcessingInfo,
+ dummy_inputs=KimiVLDummyInputsBuilder)
+class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
+
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ model_config = vllm_config.model_config
+ config: KimiVLConfig = model_config.hf_config
+ self.config = config
+ quant_config = vllm_config.quant_config
+
+ assert isinstance(config.vision_config, MoonViTConfig)
+
+ self.vision_tower = MoonVitPretrainedModel(config.vision_config)
+
+ self.multi_modal_projector = KimiVLMultiModalProjector(config=config)
+
+ self.quant_config = quant_config
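+        # The language model is a DeepSeek-V2 decoder; build it from a copy of
+        # the vLLM config whose hf_config is swapped for the text sub-config.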
+ sub_vllm_config = copy.deepcopy(vllm_config)
+ sub_vllm_config.model_config.hf_config = sub_vllm_config.model_config.hf_config.text_config
+ self.language_model = DeepseekV2Model(
+ vllm_config=sub_vllm_config,
+ prefix=maybe_prefix(prefix, "language_model"),
+ )
+ self.unpadded_vocab_size = config.text_config.vocab_size
+ self.lm_head = ParallelLMHead(
+ self.unpadded_vocab_size,
+ config.text_config.hidden_size,
+ org_num_embeddings=self.config.text_config.vocab_size,
+ padding_size=DEFAULT_VOCAB_PADDING_SIZE)
+ logit_scale = getattr(config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.text_config.vocab_size,
+                                                logit_scale)
+ self.sampler = get_sampler()
+ self.media_placeholder: int = self.config.media_placeholder_token_id
+ self.tp_rank = get_tensor_model_parallel_rank()
+ self.tp_world_size = get_tensor_model_parallel_world_size()
+
+ # ref: qwen2_vl.py
+ def _validate_and_reshape_mm_tensor(self, mm_input: object,
+ name: str) -> torch.Tensor:
+ if not isinstance(mm_input, (torch.Tensor, list)):
+ raise ValueError(f"Incorrect type of {name}. "
+ f"Got type: {type(mm_input)}")
+ if isinstance(mm_input, torch.Tensor):
+ if mm_input.ndim == 2:
+ return mm_input
+ if mm_input.ndim != 3:
+ raise ValueError(f"{name} should be 2D or batched 3D tensor. "
+ f"Got ndim: {mm_input.ndim} "
+ f"(shape={mm_input.shape})")
+ return mm_input.reshape(-1, mm_input.shape[-1])
+ else:
+ return torch.concat(mm_input)
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object) -> Optional[KimiVLImageInputs]:
+ # image input type must be pixel values now
+ pixel_values = kwargs.pop("pixel_values", None)
+ image_grid_hws = kwargs.pop("image_grid_hws", None)
+
+ if pixel_values is None:
+ return None
+
+ image_grid_hws = self._validate_and_reshape_mm_tensor(
+ image_grid_hws, "image grid hws")
+ # pixel_values may have complex shapes
+ num_channels = 3
+ patch_size = self.config.vision_config.patch_size
+ if isinstance(pixel_values, list):
+ pixel_values = torch.cat([
+ x.reshape(-1, num_channels, patch_size, patch_size)
+ for x in pixel_values
+ ])
+ else:
+ pixel_values = pixel_values.reshape(-1, num_channels, patch_size,
+ patch_size)
+ # fp32 -> bf16
+ pixel_values = pixel_values.to(torch.bfloat16)
+ # image_grid_hws.shape = (N, 2)
+ assert image_grid_hws.ndim == 2, f"unexpected shape for image_grid_hws: {image_grid_hws.shape}"
+
+ return KimiVLImagePixelInputs(
+ type="pixel_values",
+ pixel_values=pixel_values,
+ image_grid_hws=image_grid_hws,
+ )
+
+    # Run the vision tower on the processed pixel_values
+ @torch.inference_mode()
+ def _process_image_pixels(self,
+ inputs: KimiVLImagePixelInputs) -> torch.Tensor:
+ assert self.vision_tower is not None
+
+ pixel_values = inputs["pixel_values"]
+ image_grid_hws = inputs["image_grid_hws"]
+ return self.vision_tower(pixel_values, image_grid_hws)
+
+ def _process_image_input(self,
+ image_input: KimiVLImageInputs) -> torch.Tensor:
+ assert image_input["type"] == "pixel_values"
+ image_features = self._process_image_pixels(image_input)
+ assert isinstance(image_features, list)
+ lengths = [x.shape[0] for x in image_features]
+ return self.multi_modal_projector(
+ torch.cat(image_features)).split(lengths)
+
+ def get_multimodal_embeddings(self,
+ **kwargs: object) -> Optional[NestedTensors]:
+ # Validate the multimodal input keyword arguments
+ image_input = self._parse_and_validate_image_input(**kwargs)
+ if image_input is None:
+ return None
+
+ # Run multimodal inputs through encoder and projector
+ vision_embeddings = self._process_image_input(image_input)
+ return vision_embeddings
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[NestedTensors] = None,
+ ) -> torch.Tensor:
+
+ # `get_input_embeddings` should already be implemented for the language
+ # model as one of the requirements of basic vLLM model implementation.
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+
+ if multimodal_embeddings is not None:
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ multimodal_embeddings=multimodal_embeddings,
+ placeholder_token_id=self.config.media_placeholder_token_id)
+
+ return inputs_embeds
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs: object,
+ ) -> SamplerOutput:
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+ # NOTE: In v1, inputs_embeds is always generated at model runner from
+ # `get_multimodal_embeddings` and `get_input_embeddings`, this
+ # condition is only for v0 compatibility.
+ elif inputs_embeds is None:
+ image_input = self._parse_and_validate_image_input(**kwargs)
+ if image_input is None:
+ inputs_embeds = None
+ else:
+ inputs_embeds = self.get_input_embeddings(input_ids)
+ image_embeds = self._process_image_input(image_input)
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids,
+ inputs_embeds,
+ image_embeds,
+ placeholder_token_id=self.config.
+ media_placeholder_token_id,
+ )
+ input_ids = None
+
+ hidden_states = self.language_model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ )
+
+ return hidden_states
+
+ def compute_logits(self, hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ **kwargs) -> torch.Tensor:
+ logits = self.logits_processor(self.lm_head, hidden_states,
+ sampling_metadata, **kwargs)
+ return logits
+
+ def sample(
+ self,
+ logits: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[SamplerOutput]:
+ next_tokens = self.sampler(logits, sampling_metadata)
+ return next_tokens
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ config = self.config.text_config
+ _KEYS_TO_MODIFY_MAPPING = {
+ "language_model.lm_head": "lm_head",
+ "language_model.model": "language_model",
+ }
+ # only doing this for language model part for now.
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ (".gate_up_proj", ".gate_proj", 0),
+ (".gate_up_proj", ".up_proj", 1),
+ ]
+ if not config.use_mla:
+ stacked_params_mapping += [
+ (".qkv_proj", ".q_proj", "q"),
+ (".qkv_proj", ".k_proj", "k"),
+ (".qkv_proj", ".v_proj", "v"),
+ ]
+ if getattr(config, "n_routed_experts", None):
+ # Params for weights, fp8 weight scales, fp8 activation scales
+ # (param_name, weight_name, expert_id, shard_id)
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=config.n_routed_experts)
+ else:
+ expert_params_mapping = []
+
+ params_dict = dict(self.named_parameters())
+ for args in weights:
+ name, loaded_weight = args[:2]
+ kwargs = args[2] if len(args) > 2 else {}
+ if "rotary_emb.inv_freq" in name:
+ continue
+
+ spec_layer = get_spec_layer_idx_from_weight_name(config, name)
+ if spec_layer is not None:
+ continue # skip spec decode layers for main model
+
+ if ("rotary_emb.cos_cached" in name
+ or "rotary_emb.sin_cached" in name):
+ # Models trained using ColossalAI may include these tensors in
+ # the checkpoint. Skip them.
+ continue
+ for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
+ if key_to_modify in name:
+ name = name.replace(key_to_modify, new_key)
+ use_default_weight_loading = False
+ if "vision" in name:
+ if self.vision_tower is not None:
+ # We only do sharding for language model and
+ # not vision model for now.
+ use_default_weight_loading = True
+ else:
+ for (param_name, weight_name,
+ shard_id) in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ # We have mlp.experts[0].gate_proj in the checkpoint.
+ # Since we handle the experts below in expert_params_mapping,
+ # we need to skip here BEFORE we update the name, otherwise
+ # name will be updated to mlp.experts[0].gate_up_proj, which
+ # will then be updated below in expert_params_mapping
+ # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+ if (("mlp.experts." in name) and name not in params_dict):
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id, **kwargs)
+ break
+ else:
+ for idx, (param_name, weight_name, expert_id,
+ shard_id) in enumerate(expert_params_mapping):
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param,
+ loaded_weight,
+ name,
+ expert_id=expert_id,
+ shard_id=shard_id,
+ **kwargs)
+ break
+ else:
+ use_default_weight_loading = True
+ if use_default_weight_loading:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ # Remapping the name of FP8 kv-scale.
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
+
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight, **kwargs)
+
+
+def get_spec_layer_idx_from_weight_name(config: DeepseekV2Config,
+ weight_name: str) -> Optional[int]:
+ if hasattr(config,
+ "num_nextn_predict_layers") and (config.num_nextn_predict_layers
+ > 0):
+ layer_idx = config.num_hidden_layers
+ for i in range(config.num_nextn_predict_layers):
+ if weight_name.startswith(f"model.layers.{layer_idx+i}."):
+ return layer_idx + i
+ return None
diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py
new file mode 100644
index 00000000..c367d90f
--- /dev/null
+++ b/vllm/model_executor/models/moonvit.py
@@ -0,0 +1,628 @@
+# SPDX-License-Identifier: Apache-2.0
+# ruff: noqa: E501
+# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
+# This file is meant to be used in kimi_vl.py only
+# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
+#
+# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
+#
+# Licensing Information:
+# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
+# - Other parts of the code are licensed under the MIT License.
+#
+# Apache License, Version 2.0:
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# MIT License:
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import math
+from copy import deepcopy
+from functools import cached_property
+from typing import List, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.activations import ACT2FN, PytorchGELUTanh
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import is_flash_attn_2_available
+
+from vllm.transformers_utils.configs.moonvit import MoonViTConfig
+
+if is_flash_attn_2_available():
+ from flash_attn import flash_attn_varlen_func
+else:
+ flash_attn_varlen_func = None
+
+
+def multihead_attention(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ q_cu_seqlens: Optional[torch.Tensor] = None,
+ k_cu_seqlens: Optional[torch.Tensor] = None,
+):
+ """Multi-head attention using flash attention 2.
+
+ Args:
+ q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
+ or (tot_seqlens, num_heads, head_dim) if packing.
+ q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
+ The first element should be 0 and the last element should be q.shape[0].
+ k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
+ The first element should be 0 and the last element should be k.shape[0].
+
+ Returns:
+ output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
+ where dim = num_heads * head_dim
+ """
+ # Unified format legal check
+ assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
+ assert q_cu_seqlens[-1] == q.shape[
+ 0], "q_cu_seqlens must sum to q.shape[0]"
+ assert (k_cu_seqlens[-1] == k.shape[0] ==
+ v.shape[0]), "k_cu_seqlens must sum to k.shape[0]"
+ assert q.dtype in [
+ torch.bfloat16,
+ torch.float16,
+ ], f"unsupported dtype {q.dtype} for multihead attn"
+
+ max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item()
+ max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item()
+ attn_out = flash_attn_varlen_func(
+ q,
+ k,
+ v,
+ q_cu_seqlens,
+ k_cu_seqlens,
+ max_seqlen_q,
+ max_seqlen_k,
+ causal=False,
+ )
+ attn_out = attn_out.flatten(start_dim=-2)
+
+ return attn_out
+
+
+def sdpa_attention(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ q_cu_seqlens: Optional[torch.Tensor] = None,
+ k_cu_seqlens: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ """SDPA attention.
+
+ Args:
+ q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
+ or (tot_seqlens, num_heads, head_dim) if packing.
+ """
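+    # Build a block-diagonal boolean mask so that each packed sequence only
+    # attends to its own tokens.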
+ seq_length = q.shape[0]
+ attention_mask = torch.zeros([1, seq_length, seq_length],
+ device=q.device,
+ dtype=torch.bool)
+ for i in range(1, len(q_cu_seqlens)):
+ attention_mask[
+ ...,
+ q_cu_seqlens[i - 1]:q_cu_seqlens[i],
+ q_cu_seqlens[i - 1]:q_cu_seqlens[i],
+ ] = True
+ q = q.transpose(0, 1)
+ k = k.transpose(0, 1)
+ v = v.transpose(0, 1)
+ attn_output = F.scaled_dot_product_attention(q,
+ k,
+ v,
+ attention_mask,
+ dropout_p=0.0)
+ attn_output = attn_output.transpose(0, 1)
+ attn_output = attn_output.reshape(seq_length, -1)
+ return attn_output
+
+
+VL_VISION_ATTENTION_FUNCTIONS = {
+ "flash_attention_2": multihead_attention,
+ "sdpa": sdpa_attention,
+}
+
+
+def _apply_rope_input_validation(x, freqs_cis):
+ assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
+ assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape)
+ assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape)
+ assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype
+
+
+def apply_rope(xq: torch.Tensor, xk: torch.Tensor,
+ freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args: (The leading dimensions of all inputs should be the same)
+ xq: query, tensor of shape (..., num_heads, head_dim)
+ xk: key, tensor of shape (..., num_heads, head_dim)
+ freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid.
+ Returns:
+ xq_out, xk_out: tensors of shape (..., num_heads, head_dim)
+ """
+ _apply_rope_input_validation(xq, freqs_cis)
+ _apply_rope_input_validation(xk, freqs_cis)
+
+ freqs_cis = freqs_cis.unsqueeze(-2) # ..., 1, head_dim/2
+ # ..., num_heads, head_dim/2
+ xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
+    xk_ = torch.view_as_complex(xk.float().view(*xk.shape[:-1], -1, 2))
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(
+ -2) # ..., num_heads, head_dim
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(
+ -2) # ..., num_heads, head_dim
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class Learnable2DInterpPosEmb(nn.Module):
+
+ def __init__(self,
+ height: int,
+ width: int,
+ dim: int,
+ interpolation_mode: str = "bicubic") -> None:
+ super().__init__()
+ self.height = height
+ self.width = width
+ self.interpolation_mode = interpolation_mode
+ self.weight = nn.Parameter(torch.empty(height, width, dim))
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.normal_(self.weight)
+
+ def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
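+        # Reuse the learned grid directly when an image's (h, w) matches it;
+        # otherwise interpolate the embedding to that grid size first.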
+ pos_embs = []
+ for shape in grid_hws.tolist():
+ if shape == self.weight.shape[:-1]:
+ pos_embs.append(self.weight.flatten(end_dim=1))
+ else:
+ pos_embs.append(
+ F.interpolate(
+ self.weight.permute((2, 0, 1)).unsqueeze(0),
+ size=shape,
+ mode=self.interpolation_mode,
+ ).squeeze(0).permute((1, 2, 0)).flatten(end_dim=1))
+ out = x + torch.cat(pos_embs)
+ return out
+
+
+class MoonVisionPatchEmbed(nn.Module):
+
+ def __init__(
+ self,
+ out_dim: int,
+ in_dim: int = 3,
+ patch_size: Union[int, Tuple[int, int]] = (14, 14),
+ pos_emb_height: int = 14,
+ pos_emb_width: int = 14,
+ ):
+ super().__init__()
+ assert isinstance(
+ patch_size,
+ (int, Sequence)), f"Invalid patch_size type: {type(patch_size)}"
+ if isinstance(patch_size, int):
+ patch_size = (patch_size, patch_size)
+ assert (len(patch_size) == 2
+ ), f"Expected patch_size to be a tuple of 2, got {patch_size}"
+ self.patch_size = patch_size
+
+ self.proj = nn.Conv2d(in_dim,
+ out_dim,
+ kernel_size=patch_size,
+ stride=patch_size)
+
+ self.pos_emb = Learnable2DInterpPosEmb(height=pos_emb_height,
+ width=pos_emb_width,
+ dim=out_dim)
+
+ def forward(self, x: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+            x (L, C, patch_h, patch_w): packed image patches
+ grid_hw (N, 2): grid height and width
+
+ Returns:
+ (L, Cout) tensor
+ """
+ x = self.proj(x).view(x.size(0), -1)
+ # apply positional embedding
+ x = self.pos_emb(x, grid_hw)
+ return x
+
+
+class Rope2DPosEmb(nn.Module):
+ """2D rotary position embedding with multi-resolution support.
+
+ This class is intended to be used in the following way:
+ 1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis.
+ 2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration.
+ 3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation.
+ The rope is shared across all attention layers and all heads.
+
+ Refs:
+ - RoFormer: https://arxiv.org/abs/2104.09864
+ - VisionLLaMA: https://arxiv.org/abs/2403.00522
+ - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py
+
+ Args:
+ dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed)
+ max_height (int): the maximum height of the 2D grid
+ max_width (int): the maximum width of the 2D grid
+ theta_base (float): the base of the theta
+ device (str): the device to store the precomputed cis
+ """
+
+ def __init__(self,
+ dim: int,
+ max_height: int,
+ max_width: int,
+ theta_base=10000,
+ device="cuda"):
+ super().__init__()
+ self.dim = dim
+ assert self.dim % 4 == 0, "dim must be divisible by 4"
+ self.max_height = max_height
+ self.max_width = max_width
+ self.theta_base = theta_base
+ self.device = device
+
+ def extra_repr(self):
+ return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"
+
+ @cached_property
+ def precomputed_freqs_cis(self) -> torch.Tensor:
+ """Calculate the cis(freqs) for each position in the 2D grid.
+
+ Return: complex tensor of shape (max_height, max_width, dim//2) and value:
+ height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim))
+            width axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)) with (i in [0, dim//4))
+            note: `cis` is a mathematical notation defined by cis x = cos x + i sin x.
+ """
+ N = self.max_height * self.max_width
+ flat_pos = torch.arange(0, N).float().to(self.device)
+ x_pos = flat_pos % self.max_width
+ y_pos = flat_pos // self.max_width
+ dim_range = (torch.arange(0, self.dim,
+ 4)[:(self.dim // 4)].float().to(self.device)
+ ) # C/4
+ freqs = 1.0 / (self.theta_base**(dim_range / self.dim))
+ x_freqs = torch.outer(x_pos, freqs).float() # N, C/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N, C/4
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) # N, C/4
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) # N, C/4
+ # N, C/4, 2
+ freqs_cis = torch.cat(
+ [x_cis.unsqueeze(dim=-1),
+ y_cis.unsqueeze(dim=-1)], dim=-1)
+ # max_height, max_width, C/2
+ freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
+ return freqs_cis
+
+ def get_freqs_cis_by_seqlens(self, grid_hws: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ grid_hws (torch.Tensor): containing list of (height, width) or (t, height, width) tuples.
+ Returns:
+ freqs_cis: tensor of shape (sum(t * height * width), dim//2)
+ """
+ shapes = grid_hws.tolist()
+ assert all(1 <= h <= self.max_height and 1 <= w <= self.max_width
+ for h, w in shapes), (
+ shapes,
+ self.max_height,
+ self.max_width,
+ )
+ freqs_cis = torch.cat(
+ [
+ self.precomputed_freqs_cis[:h, :w].reshape(-1, self.dim // 2)
+ for h, w in shapes
+ ],
+ dim=0,
+ )
+ return freqs_cis
+
+ def get_freqs_cis_by_idx(self, pos_idx: torch.Tensor,
+ pos_idx_mask: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+            pos_idx: tensor of shape (..., 2); it contains the (h, w) position indices of each 2D token.
+            pos_idx_mask: a mask of shape (...), the leading dimensions should be the same as pos_idx.
+                Rope will only be applied to the tokens with True mask. `freqs_cis` for the tokens with False mask will be ones.
+ Return:
+ freqs_cis: tensor of shape (..., dim//2)
+ """
+ assert (pos_idx.shape[:-1] == pos_idx_mask.shape
+ and pos_idx.shape[-1] == 2 and pos_idx.ndim
+ == pos_idx_mask.ndim + 1), (pos_idx.shape, pos_idx_mask.shape)
+ assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype
+
+ shp = pos_idx_mask.shape + (self.dim // 2, ) # ..., head_dim/2
+ freqs_cis = torch.ones(shp, dtype=torch.complex64,
+ device=self.device) # ..., head_dim/2
+ freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[pos_idx[
+ ..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask]]
+ return freqs_cis
+
+
+class MLP2(nn.Module):
+ """
+ Args:
+ dims: [in_dim, hidden_dim, out_dim]
+ bias: whether to use bias in linear layer.
+ """
+
+ def __init__(self, dims: list[int], activation, bias=True):
+ super().__init__()
+ assert len(dims) == 3
+ self.fc0 = nn.Linear(dims[0], dims[1], bias=bias)
+ self.fc1 = nn.Linear(dims[1], dims[2], bias=bias)
+ self.activation = activation
+ for m in [self.fc0, self.fc1]:
+ nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features))
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.fc0(x)
+ x = self.activation(x)
+ return self.fc1(x)
+
+
+class MoonVitEncoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ num_heads: int,
+ hidden_dim: int,
+ mlp_dim: int,
+ *,
+ attn_implementation: str = "sdpa",
+ activation=F.gelu,
+ attn_bias: bool = False,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ self.hidden_dim = hidden_dim
+ self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
+ self.attn_implementation = attn_implementation
+ # use fa2 in vllm by default
+ if is_flash_attn_2_available():
+ self.attn_implementation = "flash_attention_2"
+
+ self.norm0 = nn.LayerNorm(hidden_dim)
+ self.norm1 = nn.LayerNorm(hidden_dim)
+ self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
+ self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
+ self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)
+
+ def attention_qkvpacked(
+ self,
+ x: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rope_freqs_cis: Optional[torch.Tensor] = None,
+ ):
+ """
+ Args:
+ x (torch.Tensor): (batch_size, seqlen, hidden_dim)
+            cu_seqlens (torch.Tensor): cumulative sequence lengths of the packed sequences; the first element is 0 and the last equals x.shape[0].
+ """
+ xqkv = self.wqkv(x)
+
+ qkv_shape = xqkv.size()[:-1] + (
+ 3,
+ self.num_heads,
+ self.hidden_size_per_attention_head,
+ )
+ # xqkv: (batch_size, seqlen, 3, nheads, headdim)
+ xqkv = xqkv.view(*qkv_shape)
+ xq, xk, xv = torch.unbind(xqkv, dim=-3)
+
+ xq, xk = apply_rope(xq, xk, rope_freqs_cis)
+
+ attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
+ attn_out = attn_func(xq,
+ xk,
+ xv,
+ q_cu_seqlens=cu_seqlens,
+ k_cu_seqlens=cu_seqlens)
+
+ attn_out = self.wo(attn_out)
+ return attn_out
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rope_freqs_cis: Union[torch.Tensor, None] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+            hidden_states: non-packed (B, N, D) or packed (L, D). If non-packed, cu_seqlens should be None; if packed, cu_seqlens must be set.
+
+        Returns:
+            output: same shape as the input; non-packed (B, N, D) for non-packed input, (L, D) for packed input
+ """
+ residual = hidden_states
+ hidden_states = self.norm0(hidden_states)
+ attn_out = self.attention_qkvpacked(hidden_states,
+ cu_seqlens,
+ rope_freqs_cis=rope_freqs_cis)
+ hidden_states = residual + attn_out
+
+ residual = hidden_states
+ hidden_states = self.mlp(self.norm1(hidden_states))
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+class MoonVitEncoder(nn.Module):
+
+ def __init__(
+ self,
+ hidden_dim: int,
+ num_layers: int,
+ block_cfg: dict,
+ ) -> None:
+ super().__init__()
+
+ self.rope_2d = Rope2DPosEmb(
+ block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512)
+ self.blocks = nn.ModuleList(
+ [MoonVitEncoderLayer(**block_cfg) for _ in range(num_layers)])
+ self.final_layernorm = nn.LayerNorm(hidden_dim)
+
+ def forward(self, hidden_states: torch.Tensor,
+ grid_hw: torch.Tensor) -> torch.Tensor:
+ rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(
+ grid_hws=grid_hw)
+
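+        # Per-image token counts (h * w), prefixed with 0 and cumulatively
+        # summed, give the cu_seqlens boundaries used by the attention blocks.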
+ lengths = torch.cat((
+ torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
+ grid_hw[:, 0] * grid_hw[:, 1],
+ ))
+ cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)
+
+ for _, block in enumerate(self.blocks):
+ hidden_states = block(hidden_states,
+ cu_seqlens,
+ rope_freqs_cis=rope_freqs_cis)
+
+ hidden_states = self.final_layernorm(hidden_states)
+
+ return hidden_states
+
+
+def patch_merger(
+ x: torch.Tensor,
+ grid_hw: torch.Tensor,
+    merge_kernel_size: tuple[int, int] = (2, 2),
+) -> List[torch.Tensor]:
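+    # For each image with grid (h, w), group its tokens into non-overlapping
+    # kernel_height x kernel_width windows, yielding one tensor of shape
+    # (h // kh * w // kw, kh * kw, d_model) per image.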
+ d_model = x.size(-1)
+
+ outputs = []
+ pre_sum = 0
+ for x_shape in grid_hw.tolist():
+ height, width = x_shape[0], x_shape[1]
+ # Get the current sequence
+ seq = x[pre_sum:pre_sum + height * width]
+ # Reshape along self.merge_kernel_size and concat to the last dimension
+ kernel_height, kernel_width = merge_kernel_size
+ new_height, new_width = height // kernel_height, width // kernel_width
+ reshaped_seq = seq.view(new_height, kernel_height, new_width,
+ kernel_width, d_model)
+ reshaped_seq = reshaped_seq.permute(0, 2, 1, 3, 4).contiguous()
+ padded_seq = reshaped_seq.view(new_height * new_width,
+ kernel_height * kernel_width, -1)
+ outputs.append(padded_seq)
+ pre_sum += height * width
+
+ return outputs
+
+
+class MoonVitVLProjector(nn.Module):
+
+ def __init__(
+ self,
+ in_channels: int,
+        merge_kernel_size: tuple[int, int],
+ hidden_act: str = "gelu",
+ ln_eps: float = 1e-5,
+ out_dim: int = 4096,
+ ):
+ super().__init__()
+ self.hidden_size = in_channels * merge_kernel_size[
+ 0] * merge_kernel_size[1]
+
+        self.pre_norm = nn.LayerNorm(in_channels, eps=ln_eps)
+ self.linear_1 = nn.Linear(self.hidden_size,
+ self.hidden_size,
+ bias=True)
+ self.act = ACT2FN[hidden_act]
+ self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size)
+ hidden_states = self.linear_1(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+class MoonVitPretrainedModel(PreTrainedModel):
+ config_class = MoonViTConfig
+ model_type = "moonvit"
+ _no_split_modules = ["PackingTransformer"]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+
+ def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ config = deepcopy(config)
+ self.merge_kernel_size = config.merge_kernel_size
+ self.patch_size = config.patch_size
+ self.patch_embed = MoonVisionPatchEmbed(
+ out_dim=config.hidden_size,
+ patch_size=config.patch_size,
+ pos_emb_height=config.init_pos_emb_height,
+ pos_emb_width=config.init_pos_emb_width,
+ )
+
+ self.encoder = MoonVitEncoder(
+ hidden_dim=config.hidden_size,
+ num_layers=config.num_hidden_layers,
+ block_cfg={
+ "num_heads": config.num_attention_heads,
+ "hidden_dim": config.hidden_size,
+ "mlp_dim": config.intermediate_size,
+ "activation": PytorchGELUTanh(),
+ "attn_bias": True,
+ "attn_implementation": config._attn_implementation,
+ },
+ )
+
+ def forward(self, pixel_values: torch.Tensor,
+ grid_hw: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ pixel_values (torch.Tensor): The input pixel values.
+ grid_hw (torch.Tensor): The grid height and width.
+
+ Returns:
+ torch.Tensor: The output tokens.
+ """
+ hidden_states = self.patch_embed(pixel_values, grid_hw)
+ hidden_states = self.encoder(hidden_states, grid_hw)
+ hidden_states = patch_merger(hidden_states,
+ grid_hw,
+ merge_kernel_size=self.merge_kernel_size)
+ return hidden_states
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 0d13d699..b345113e 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -177,6 +177,7 @@ _MULTIMODAL_MODELS = {
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
"SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501
+ "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
"LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index fe0319c9..f37605be 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -33,12 +33,13 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
EAGLEConfig, ExaoneConfig,
H2OVLChatConfig,
InternVLChatConfig, JAISConfig,
- MedusaConfig, MllamaConfig,
- MLPSpeculatorConfig, MPTConfig,
- NemotronConfig, NVLM_D_Config,
- Olmo2Config, RWConfig,
- SkyworkR1VChatConfig, SolarConfig,
- Telechat2Config, UltravoxConfig)
+ KimiVLConfig, MedusaConfig,
+ MllamaConfig, MLPSpeculatorConfig,
+ MPTConfig, NemotronConfig,
+ NVLM_D_Config, Olmo2Config,
+ RWConfig, SkyworkR1VChatConfig,
+ SolarConfig, Telechat2Config,
+ UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname
@@ -62,6 +63,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"cohere2": Cohere2Config,
"dbrx": DbrxConfig,
"deepseek_vl_v2": DeepseekVLV2Config,
+ "kimi_vl": KimiVLConfig,
"mpt": MPTConfig,
"RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct)
"RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 53699341..739eea5c 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -13,9 +13,11 @@ from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig
from vllm.transformers_utils.configs.internvl import InternVLChatConfig
from vllm.transformers_utils.configs.jais import JAISConfig
+from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
from vllm.transformers_utils.configs.medusa import MedusaConfig
from vllm.transformers_utils.configs.mllama import MllamaConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
+from vllm.transformers_utils.configs.moonvit import MoonViTConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
@@ -40,6 +42,8 @@ __all__ = [
"ExaoneConfig",
"MllamaConfig",
"MLPSpeculatorConfig",
+ "MoonViTConfig",
+ "KimiVLConfig",
"NemotronConfig",
"NVLM_D_Config",
"Olmo2Config",
diff --git a/vllm/transformers_utils/configs/kimi_vl.py b/vllm/transformers_utils/configs/kimi_vl.py
new file mode 100644
index 00000000..97ff44bb
--- /dev/null
+++ b/vllm/transformers_utils/configs/kimi_vl.py
@@ -0,0 +1,36 @@
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
+from typing import Optional, Union
+
+from transformers.configuration_utils import PretrainedConfig
+
+from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
+from vllm.transformers_utils.configs.moonvit import MoonViTConfig
+
+
+class KimiVLConfig(PretrainedConfig):
+ model_type = "kimi_vl"
+
+ def __init__(self,
+ vision_config: Optional[Union[dict, MoonViTConfig]] = None,
+ text_config: Optional[Union[dict, DeepseekV2Config]] = None,
+ ignore_index: int = -100,
+ media_placeholder_token_id: int = 163605,
+ pad_token_id: int = 0,
+ **kwargs):
+ if vision_config is None:
+ vision_config = MoonViTConfig()
+ elif isinstance(vision_config, dict):
+ vision_config = MoonViTConfig(**vision_config)
+ self.vision_config = vision_config
+
+ if text_config is None:
+ text_config = DeepseekV2Config()
+ elif isinstance(text_config, dict):
+ text_config = DeepseekV2Config(**text_config)
+ self.text_config = text_config
+
+ self.ignore_index = ignore_index
+ self.media_placeholder_token_id = media_placeholder_token_id
+
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
diff --git a/vllm/transformers_utils/configs/moonvit.py b/vllm/transformers_utils/configs/moonvit.py
new file mode 100644
index 00000000..a2b4059a
--- /dev/null
+++ b/vllm/transformers_utils/configs/moonvit.py
@@ -0,0 +1,32 @@
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
+from transformers.configuration_utils import PretrainedConfig
+
+
+class MoonViTConfig(PretrainedConfig):
+ model_type = "moonvit"
+
+ def __init__(
+ self,
+ patch_size: int = 14,
+ init_pos_emb_height: int = 64,
+ init_pos_emb_width: int = 64,
+ num_attention_heads: int = 16,
+ num_hidden_layers: int = 27,
+ hidden_size: int = 1152,
+ intermediate_size: int = 4304,
+ merge_kernel_size: tuple[int, int] = (2, 2),
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.patch_size = patch_size
+ # Positional embedding config
+ self.init_pos_emb_height = init_pos_emb_height
+ self.init_pos_emb_width = init_pos_emb_width
+ # Transformer config
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ # Patch merger config
+ self.merge_kernel_size = merge_kernel_size
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 70e8bd75..c3d84ab3 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -997,13 +997,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
- if self.is_multimodal_model:
- # Run the multimodal encoder if any.
- self._execute_mm_encoder(scheduler_output)
- mm_embeds = self._gather_mm_embeddings(scheduler_output)
- else:
- mm_embeds = []
-
# Prepare the decoder inputs.
attn_metadata, logits_indices, spec_decode_metadata = (
self._prepare_inputs(scheduler_output))
@@ -1019,6 +1012,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_input_tokens = num_scheduled_tokens
attn_metadata.num_input_tokens = num_input_tokens
+ # _prepare_inputs may reorder the batch, so we must gather multi
+ # modal outputs after that to ensure the correct order
+ if self.is_multimodal_model:
+ # Run the multimodal encoder if any.
+ self._execute_mm_encoder(scheduler_output)
+ mm_embeds = self._gather_mm_embeddings(scheduler_output)
+ else:
+ mm_embeds = []
+
if self.is_multimodal_model:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)