diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 6d214054..1b742717 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -753,6 +753,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
+- * `AyaVisionForConditionalGeneration`
+ * Aya Vision
+ * T + I+
+ * `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc.
+ *
+ * ✅︎
+ * ✅︎
- * `Blip2ForConditionalGeneration`
* BLIP-2
* T + IE
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index d32bfcd3..c1115708 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -60,6 +60,28 @@ def run_aria(questions: list[str], modality: str) -> ModelRequestData:
)
+# Aya Vision
+def run_aya_vision(questions: list[str], modality: str) -> ModelRequestData:
+ assert modality == "image"
+ model_name = "CohereForAI/aya-vision-8b"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=2048,
+ max_num_seqs=2,
+ mm_processor_kwargs={"crop_to_patches": True},
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+ )
+ prompts = [
+ f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{question}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
+ for question in questions
+ ]
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ )
+
+
# BLIP-2
def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -865,6 +887,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
model_example_map = {
"aria": run_aria,
+ "aya_vision": run_aya_vision,
"blip-2": run_blip2,
"chameleon": run_chameleon,
"deepseek_vl_v2": run_deepseek_vl2,
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index 318cf989..39951e5e 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -61,6 +61,41 @@ def load_aria(question: str, image_urls: list[str]) -> ModelRequestData:
)
+def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData:
+ model_name = "CohereForAI/aya-vision-8b"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_num_seqs=2,
+ limit_mm_per_prompt={"image": len(image_urls)},
+ )
+
+ placeholders = [{"type": "image", "image": url} for url in image_urls]
+ messages = [{
+ "role":
+ "user",
+ "content": [
+ *placeholders,
+ {
+ "type": "text",
+ "text": question
+ },
+ ],
+ }]
+
+ processor = AutoProcessor.from_pretrained(model_name)
+
+ prompt = processor.apply_chat_template(messages,
+ tokenize=False,
+ add_generation_prompt=True)
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompt=prompt,
+ image_data=[fetch_image(url) for url in image_urls],
+ )
+
+
def load_deepseek_vl2(question: str,
image_urls: list[str]) -> ModelRequestData:
model_name = "deepseek-ai/deepseek-vl2-tiny"
@@ -526,6 +561,7 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model_example_map = {
"aria": load_aria,
+ "aya_vision": load_aya_vision,
"deepseek_vl_v2": load_deepseek_vl2,
"gemma3": load_gemma3,
"h2ovl_chat": load_h2ovl,
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index 7a9158ef..3b34f012 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -158,6 +158,20 @@ VLM_TEST_SETTINGS = {
max_tokens=64,
marks=[large_gpu_mark(min_gb=64)],
),
+ "aya_vision": VLMTestInfo(
+ models=["CohereForAI/aya-vision-8b"],
+ test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+ prompt_formatter=lambda img_prompt: f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{img_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", # noqa: E501
+ single_image_prompts=IMAGE_ASSETS.prompts({
+ "stop_sign": "What's the content in the center of the image?", # noqa: E501
+ "cherry_blossom": "What is the season?", # noqa: E501
+ }),
+ multi_image_prompt="Describe the two images in detail.", # noqa: E501
+ max_model_len=8192,
+ max_num_seqs=2,
+ auto_cls=AutoModelForImageTextToText,
+ vllm_runner_kwargs={"mm_processor_kwargs": {"crop_to_patches": True}}
+ ),
"blip2": VLMTestInfo(
# TODO: Change back to 2.7b once head_dim = 80 is supported
models=["Salesforce/blip2-opt-6.7b"],
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index e4f1d297..fdcd7a9e 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -246,6 +246,7 @@ def _test_processing_correctness_mistral(
# yapf: disable
@pytest.mark.parametrize("model_id", [
"rhymes-ai/Aria",
+ "CohereForAI/aya-vision-8b",
"Salesforce/blip2-opt-2.7b",
"facebook/chameleon-7b",
"deepseek-ai/deepseek-vl2-tiny",
diff --git a/tests/models/registry.py b/tests/models/registry.py
index ffc00261..137f1418 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -259,6 +259,7 @@ _CROSS_ENCODER_EXAMPLE_MODELS = {
_MULTIMODAL_EXAMPLE_MODELS = {
# [Decoder-only]
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
+ "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), # noqa: E501
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501
extras={"6b": "Salesforce/blip2-opt-6.7b"}), # noqa: E501
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
diff --git a/vllm/config.py b/vllm/config.py
index c82c9763..c213c9b4 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2716,6 +2716,10 @@ def _get_and_verify_max_len(
max_len_key = key if max_len < derived_max_model_len \
else max_len_key
derived_max_model_len = min(derived_max_model_len, max_len)
+ # For Command-R / Cohere, Cohere2 / Aya Vision models
+ if tmp_max_len := getattr(hf_config, "model_max_length", None):
+ max_len_key = "model_max_length"
+ derived_max_model_len = tmp_max_len
# If sliding window is manually disabled, max_length should be less
# than the sliding window length in the model config.
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index e32b8ffc..ff2d1aac 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -496,8 +496,9 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
- if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat",
- "skywork_chat", "NVLM_D", "h2ovl_chat"):
+ if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2",
+ "internvl_chat", "skywork_chat", "NVLM_D",
+ "h2ovl_chat"):
return ""
if model_type == "mllama":
return "<|image|>"
diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py
new file mode 100644
index 00000000..b4bf1d82
--- /dev/null
+++ b/vllm/model_executor/models/aya_vision.py
@@ -0,0 +1,527 @@
+# SPDX-License-Identifier: Apache-2.0 Adapted from
+# https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
+from functools import cached_property
+from typing import (Iterable, Literal, Mapping, Optional, Sequence, Set, Tuple,
+ TypedDict, Union, cast)
+
+import torch
+from torch import nn
+from transformers import BatchFeature, GotOcr2ImageProcessor
+from transformers.activations import ACT2FN
+from transformers.image_processing_utils import get_size_dict
+from transformers.models.aya_vision import AyaVisionConfig
+from transformers.models.aya_vision.processing_aya_vision import (
+ AyaVisionProcessor)
+from transformers.models.got_ocr2.image_processing_got_ocr2 import (
+ get_optimal_tiled_canvas)
+
+from vllm.config import VllmConfig
+from vllm.jsontree import json_map_leaves
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import MultiModalKwargs
+from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
+ MultiModalDataItems)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ BaseProcessingInfo,
+ MultiModalFieldConfig,
+ PromptReplacement, PromptUpdate,
+ encode_tokens)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.sequence import IntermediateTensors
+
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .siglip import SiglipVisionModel
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+ maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
+
+
+class AyaVisionImagePixelInputs(TypedDict):
+ type: Literal["pixel_values"]
+ pixel_values: torch.Tensor
+ """
+ Shape: `(num_patches_total, num_channels, height, width)`
+
+ `num_patches_total` is the total number of patches over each image over each
+ prompt in the batch.
+ """
+
+ num_patches: torch.Tensor
+ """Shape: `(batch_size * num_images)`"""
+
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A boolean mask indicating which image embeddings correspond to patch tokens.
+
+ Shape: `(batch_size * num_images, num_embeds)`
+ """
+
+
+class AyaVisionMultiModalProjector(nn.Module):
+
+ def __init__(self, config: AyaVisionConfig):
+ super().__init__()
+ self.config = config
+ self.downsample_factor = config.downsample_factor
+ self.alignment_intermediate_size = getattr(
+ config, "alignment_intermediate_size",
+ config.text_config.hidden_size)
+ self.layernorm = nn.LayerNorm(config.vision_config.hidden_size *
+ (config.downsample_factor**2),
+ eps=config.adapter_layer_norm_eps)
+
+ self.linear_1 = nn.Linear(
+ config.vision_config.hidden_size * (config.downsample_factor**2),
+ self.alignment_intermediate_size,
+ bias=True,
+ )
+
+ self.act = ACT2FN["silu"] # SwiGLU uses SiLU activation
+ # For SwiGLU, project down to half size since we split intermediate dim
+ self.linear_2 = nn.Linear(self.alignment_intermediate_size // 2,
+ config.text_config.hidden_size,
+ bias=True)
+
+ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
+ image_features = self.pixel_shuffle(image_features)
+ image_features = self.layernorm(image_features)
+ hidden_states = self.linear_1(image_features)
+
+ # Split along last dimension and apply SwiGLU
+ x, gate = hidden_states.chunk(2, dim=-1)
+ hidden_states = self.act(gate) * x
+
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+ def pixel_shuffle(self,
+ image_features: torch.Tensor) -> torch.Tensor: # B, S, D
+ batch_size, seq_length, _ = image_features.shape
+ height = width = int(seq_length**0.5)
+ image_features = image_features.reshape(image_features.shape[0], width,
+ height, -1)
+ channels = image_features.shape[-1]
+ image_features = image_features.reshape(
+ batch_size, width, int(height / self.downsample_factor),
+ int(channels * self.downsample_factor))
+ image_features = image_features.permute(0, 2, 1, 3)
+ image_features = image_features.reshape(
+ batch_size, int(height / self.downsample_factor),
+ int(width / self.downsample_factor), -1)
+ image_features = image_features.permute(0, 2, 1, 3)
+ return image_features
+
+
+class AyaVisionProcessingInfo(BaseProcessingInfo):
+
+ def get_hf_config(self) -> AyaVisionConfig:
+ return self.ctx.get_hf_config(AyaVisionConfig)
+
+ def get_hf_processor(self, **kwargs: object) -> AyaVisionProcessor:
+ return self.ctx.get_hf_processor(AyaVisionProcessor, **kwargs)
+
+ def get_image_processor(self) -> GotOcr2ImageProcessor:
+ return self.get_hf_processor().image_processor
+
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ return {"image": self.get_max_image_tokens()}
+
+ def get_max_image_tokens(self) -> int:
+ hf_processor = self.get_hf_processor()
+ image_processor = hf_processor.image_processor
+ image_size = self.get_image_size_with_most_features()
+ tokenizer = hf_processor.tokenizer
+ num_patches = self.get_num_patches(
+ image_width=image_size.width,
+ image_height=image_size.height,
+ size=image_processor.size,
+ min_patches=image_processor.min_patches,
+ max_patches=image_processor.max_patches)
+ image_string = hf_processor._prompt_split_image(num_patches)
+ x = encode_tokens(
+ tokenizer,
+ image_string,
+ add_special_tokens=False,
+ )
+ return len(x)
+
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ return {"image": None}
+
+ def get_image_size_with_most_features(self) -> ImageSize:
+ image_processor = self.get_image_processor()
+ height = image_processor.size['height']
+ width = image_processor.size['width']
+ max_patches = image_processor.max_patches
+ return ImageSize(height=height * max_patches,
+ width=width * max_patches)
+
+ def get_num_patches(self, *, image_width: int, image_height: int,
+ size: dict, min_patches: int, max_patches: int) -> int:
+ """
+ Calculate the number of patches needed for a given image based on size
+ constraints. This method replicates and adjusts the logic from:
+ transformers/models/got_ocr2/image_processing_got_ocr2
+ """
+ size = get_size_dict(size, default_to_square=False)
+ num_columns, num_rows = get_optimal_tiled_canvas(
+ (image_height, image_width), (size["height"], size["width"]),
+ min_patches, max_patches)
+ num_blocks = num_columns * num_rows
+ return num_blocks if num_blocks == 1 else num_blocks + 1
+
+
+class AyaVisionDummyInputsBuilder(
+ BaseDummyInputsBuilder[AyaVisionProcessingInfo]):
+
+ def get_dummy_processor_inputs(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> ProcessorInputs:
+ processor = self.info.get_hf_processor()
+ image_token = processor.image_token
+
+ num_images = mm_counts.get("image", 0)
+ image_size = \
+ self.info.get_image_size_with_most_features()
+
+ mm_data = {
+ "image":
+ self._get_dummy_images(width=image_size.width,
+ height=image_size.height,
+ num_images=num_images)
+ }
+ return ProcessorInputs(
+ prompt_text=image_token * num_images,
+ mm_data=mm_data,
+ )
+
+
+class AyaVisionMultiModalProcessor(
+ BaseMultiModalProcessor[AyaVisionProcessingInfo]):
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ processed_outputs = super()._call_hf_processor(
+ prompt,
+ mm_data,
+ mm_kwargs,
+ )
+ hf_processor = self.info.get_hf_processor(**mm_kwargs)
+ image_processor = hf_processor.image_processor
+
+ hf_config = self.info.get_hf_config()
+ # HF processor pops the `num_patches` kwarg, which is needed by vLLM
+ if (images :=
+ mm_data.get("images")) is not None and '' in prompt:
+ assert isinstance(images, list)
+ parsed_images = (self._get_data_parser().parse_mm_data({
+ "image":
+ images
+ }).get_items("image", ImageProcessorItems))
+ image_sizes = [
+ parsed_images.get_image_size(i)
+ for i in range(len(parsed_images))
+ ]
+ num_patches = [
+ self.info.get_num_patches(
+ image_width=image_size.width,
+ image_height=image_size.height,
+ size=image_processor.size,
+ min_patches=image_processor.min_patches,
+ max_patches=image_processor.max_patches)
+ for image_size in image_sizes
+ ]
+ image_tokens_list = [
+ hf_processor._prompt_split_image(num_patch)
+ for num_patch in num_patches
+ ]
+ tokenizer = self.info.get_tokenizer()
+ image_token_ids = [
+ tokenizer.encode(image_tokens, add_special_tokens=False)
+ for image_tokens in image_tokens_list
+ ]
+ embed_is_patch = [
+ torch.tensor(image_repl_tokens) == hf_config.image_token_index
+ for image_repl_tokens in image_token_ids
+ ]
+ processed_outputs["embed_is_patch"] = embed_is_patch
+ processed_outputs["num_patches"] = torch.tensor(num_patches)
+
+ return processed_outputs
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ num_patches = hf_inputs.get("num_patches", torch.empty(0))
+ return dict(
+ pixel_values=MultiModalFieldConfig.flat_from_sizes(
+ "image", num_patches),
+ num_patches=MultiModalFieldConfig.batched("image"),
+ embed_is_patch=MultiModalFieldConfig.batched("image"),
+ image_embeds=MultiModalFieldConfig.batched("image"),
+ )
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> Sequence[PromptUpdate]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+ image_token = hf_processor.image_token
+ image_processor = hf_processor.image_processor
+
+ def get_replacement(item_idx: int):
+ images: ImageProcessorItems = mm_items.get("image",
+ ImageProcessorItems)
+ image_size: ImageSize = images.get_image_size(item_idx)
+ num_patches = self.info.get_num_patches(
+ image_width=image_size.width,
+ image_height=image_size.height,
+ size=image_processor.size,
+ min_patches=image_processor.min_patches,
+ max_patches=image_processor.max_patches)
+ return hf_processor._prompt_split_image(num_patches=num_patches)
+
+ return [
+ PromptReplacement(
+ modality="image",
+ target=image_token,
+ replacement=get_replacement,
+ )
+ ]
+
+
+def _get_num_hidden_layers(hf_config: AyaVisionConfig) -> int:
+ feature_layers = hf_config.vision_feature_layer
+ num_hidden_layers = hf_config.vision_config.num_hidden_layers
+ # If we have one feature layer, initialize up to that layer
+ if isinstance(feature_layers, int):
+ return _get_layer_index(feature_layers, num_hidden_layers)
+ # If we have multiple feature layers, initialize up to the deepest m
+ elif isinstance(feature_layers, (list, tuple)):
+ return max(
+ _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
+ raise TypeError(f"vision_layer_feature type: {type(feature_layers)}"
+ " is not supported")
+
+
+def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
+ if feature_layer_index < 0:
+ return num_hidden_layers + feature_layer_index + 1
+ return feature_layer_index
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ AyaVisionMultiModalProcessor,
+ info=AyaVisionProcessingInfo,
+ dummy_inputs=AyaVisionDummyInputsBuilder)
+class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
+ SupportsPP):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config: AyaVisionConfig = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+ num_hidden_layers = _get_num_hidden_layers(config)
+ self.config = config
+ self.quant_config = quant_config
+ self.multimodal_config = multimodal_config
+
+ self.vision_tower = SiglipVisionModel(
+ config.vision_config,
+ quant_config,
+ num_hidden_layers_override=num_hidden_layers,
+ prefix=maybe_prefix(prefix, "vision_model"))
+ self.vocab_size = config.text_config.vocab_size
+ self.multi_modal_projector = AyaVisionMultiModalProjector(config)
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ hf_config=config.text_config,
+ prefix=maybe_prefix(prefix, "model"),
+ # Cohere2ForCausalLM and CohereForCausalLM are the same on vllm
+ architectures=["Cohere2ForCausalLM"])
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+ def load_weights(self, weights: Iterable[Tuple[str,
+ torch.Tensor]]) -> Set[str]:
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights)
+
+ def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
+ pixel_values: torch.Tensor,
+ **kwargs) -> torch.Tensor:
+ target_dtype = vision_tower.get_input_embeddings().weight.dtype
+ image_features = vision_tower(pixel_values.to(dtype=target_dtype),
+ **kwargs)
+
+ def select_features(leaf: torch.Tensor):
+ return self._select_image_features(
+ leaf,
+ strategy=self.config.vision_feature_select_strategy,
+ )
+
+ return cast(
+ Union[torch.Tensor, tuple[torch.Tensor, ...]],
+ json_map_leaves(select_features, image_features),
+ )
+
+ def _select_image_features(self, image_features: torch.Tensor, *,
+ strategy: str) -> torch.Tensor:
+ if strategy == "default":
+ return image_features[:, 1:]
+ elif strategy == "full":
+ return image_features
+
+ raise ValueError(f"Unexpected select feature strategy: {strategy}")
+
+ def _process_image_input(self, image_input: AyaVisionImagePixelInputs,
+ **kwargs) -> list[torch.Tensor]:
+ assert self.vision_tower is not None
+ pixel_values = image_input["pixel_values"]
+ num_patches = image_input["num_patches"]
+ image_features = self._image_pixels_to_features(
+ self.vision_tower, pixel_values=pixel_values)
+ image_embeds = self.multi_modal_projector(image_features)
+ return [
+ e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())
+ ]
+
+ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
+ h = w = self.config.vision_config.image_size
+ expected_dims = (3, h, w)
+
+ def _validate_shape(d: torch.Tensor):
+ if d.shape != expected_dims:
+ raise ValueError(
+ "The expected shape of pixel values per image per batch "
+ f"is {expected_dims}. You supplied {tuple(d.shape)}.")
+
+ for d in data:
+ _validate_shape(d)
+
+ return data
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]:
+ pixel_values = kwargs.pop("pixel_values", None)
+ num_patches = kwargs.pop("num_patches", None)
+ embed_is_patch = kwargs.pop("embed_is_patch", None)
+ image_embeds = kwargs.pop("image_embeds", None)
+ assert image_embeds is None, "Aya Vision does not support image_embeds."
+
+ if not isinstance(pixel_values, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of pixel values. "
+ f"Got type: {type(pixel_values)}")
+ if num_patches is not None and not isinstance(num_patches,
+ (torch.Tensor, list)):
+ raise ValueError("Incorrect type of num_patches. "
+ f"Got type: {type(num_patches)}")
+
+ if not isinstance(embed_is_patch, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of embed_is_patch. "
+ f"Got type: {type(embed_is_patch)}")
+
+ pixel_values = flatten_bn(pixel_values, concat=True)
+ num_patches = flatten_bn(num_patches, concat=True)
+ embed_is_patch = flatten_bn(embed_is_patch)
+ return AyaVisionImagePixelInputs(
+ type="pixel_values",
+ pixel_values=self._validate_pixel_values(pixel_values),
+ num_patches=num_patches,
+ embed_is_patch=embed_is_patch,
+ )
+
+ def get_multimodal_embeddings(
+ self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+ image_input = self._parse_and_validate_image_input(**kwargs)
+ if image_input is None:
+ return None
+ image_features = self._process_image_input(image_input, **kwargs)
+ return scatter_patch_features(
+ image_features,
+ image_input["embed_is_patch"],
+ )
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+ ) -> torch.Tensor:
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+ if multimodal_embeddings is not None:
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ multimodal_embeddings=select_patch_features(
+ multimodal_embeddings),
+ placeholder_token_id=self.config.image_token_index)
+
+ return inputs_embeds
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs: object,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+ # NOTE: In v1, inputs_embeds is always generated at model runner, this
+ # condition is for v0 compatibility.
+ elif inputs_embeds is None:
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+ inputs_embeds = self.get_input_embeddings(input_ids,
+ vision_embeddings)
+ input_ids = None
+
+ hidden_states = self.language_model.model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ )
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ return self.language_model.compute_logits(hidden_states,
+ sampling_metadata)
+
+ @cached_property
+ def sampler(self):
+ if hasattr(self.language_model, "sampler"):
+ return self.language_model.sampler
+
+ return get_sampler()
+
+ def sample(
+ self,
+ logits: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[SamplerOutput]:
+ return self.language_model.sample(logits, sampling_metadata)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 5211cd08..2f1827c1 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -161,6 +161,7 @@ _CROSS_ENCODER_MODELS = {
_MULTIMODAL_MODELS = {
# [Decoder-only]
"AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
+ "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"), # noqa: E501
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501
"DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),