From d55244df31969e7df435603b5d7014939e60881b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?=
Date: Wed, 9 Apr 2025 13:12:54 +0200
Subject: [PATCH] [Model] Add `SupportsMultiModal.get_language_model` interface (#16007)

Signed-off-by: NickLucche
---
 docs/source/contributing/model/multimodal.md   | 11 +++++++++++
 vllm/model_executor/models/aria.py             |  3 +++
 vllm/model_executor/models/aya_vision.py       |  3 +++
 vllm/model_executor/models/blip2.py            |  3 +++
 vllm/model_executor/models/chameleon.py        |  3 +++
 vllm/model_executor/models/deepseek_vl2.py     |  3 +++
 vllm/model_executor/models/florence2.py        |  3 +++
 vllm/model_executor/models/fuyu.py             |  3 +++
 vllm/model_executor/models/gemma3_mm.py        |  3 +++
 vllm/model_executor/models/glm4v.py            |  3 +++
 vllm/model_executor/models/idefics3.py         |  3 +++
 vllm/model_executor/models/interfaces.py       | 12 ++++++++++++
 vllm/model_executor/models/internvl.py         |  3 +++
 vllm/model_executor/models/llava.py            |  3 +++
 vllm/model_executor/models/llava_next.py       |  3 +++
 vllm/model_executor/models/llava_next_video.py |  3 +++
 vllm/model_executor/models/llava_onevision.py  |  3 +++
 vllm/model_executor/models/minicpmv.py         |  3 +++
 vllm/model_executor/models/mistral3.py         |  3 +++
 vllm/model_executor/models/mllama.py           |  3 +++
 vllm/model_executor/models/mllama4.py          |  3 +++
 vllm/model_executor/models/molmo.py            |  3 +++
 vllm/model_executor/models/paligemma.py        |  3 +++
 vllm/model_executor/models/phi3v.py            |  3 +++
 vllm/model_executor/models/phi4mm.py           |  3 +++
 vllm/model_executor/models/pixtral.py          |  3 +++
 vllm/model_executor/models/qwen2_5_vl.py       |  3 +++
 vllm/model_executor/models/qwen2_audio.py      |  3 +++
 vllm/model_executor/models/qwen2_vl.py         |  3 +++
 vllm/model_executor/models/qwen_vl.py          |  3 +++
 vllm/model_executor/models/skyworkr1v.py       |  3 +++
 vllm/model_executor/models/ultravox.py         |  3 +++
 vllm/model_executor/models/whisper.py          |  3 +++
 33 files changed, 116 insertions(+)

diff --git a/docs/source/contributing/model/multimodal.md b/docs/source/contributing/model/multimodal.md
index c4894d39..0c749633 100644
--- a/docs/source/contributing/model/multimodal.md
+++ b/docs/source/contributing/model/multimodal.md
@@ -79,6 +79,17 @@ Further update the model as follows:
         return inputs_embeds
   ```
 
+- Implement the {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model` getter to provide stable access to the underlying language model.
+
+  ```python
+  class YourModelForImage2Seq(nn.Module):
+      ...
+
+      def get_language_model(self) -> torch.nn.Module:
+          # Change `language_model` according to your implementation.
+          return self.language_model
+  ```
+
 - Once the above steps are done, update the model class with the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py
index 8cd3be90..af340fef 100644
--- a/vllm/model_executor/models/aria.py
+++ b/vllm/model_executor/models/aria.py
@@ -605,6 +605,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
 
         return self.multi_modal_projector(image_outputs, image_attn_mask)
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py
index 6b68885d..929c8f2a 100644
--- a/vllm/model_executor/models/aya_vision.py
+++ b/vllm/model_executor/models/aya_vision.py
@@ -424,6 +424,9 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
             num_patches=num_patches,
         )
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py
index db9d42f5..a1f20ea4 100644
--- a/vllm/model_executor/models/blip2.py
+++ b/vllm/model_executor/models/blip2.py
@@ -627,6 +627,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
 
         return self.language_projection(query_output)
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py
index 3d527cb6..d46ae532 100644
--- a/vllm/model_executor/models/chameleon.py
+++ b/vllm/model_executor/models/chameleon.py
@@ -988,6 +988,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
             data=self._validate_pixel_values(pixel_values),
         )
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py
index 4554a997..03d5be29 100644
--- a/vllm/model_executor/models/deepseek_vl2.py
+++ b/vllm/model_executor/models/deepseek_vl2.py
@@ -604,6 +604,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         return self._pixel_values_to_embedding(
             pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py
index 70b8d51b..62fd0939 100644
--- a/vllm/model_executor/models/florence2.py
+++ b/vllm/model_executor/models/florence2.py
@@ -1050,6 +1050,9 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
         pixel_values = image_input["data"]
         return self._encode_image(pixel_values)
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index 189b91db..c0a0f572 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -341,6 +341,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
         return vision_embeddings_flat.split(patches_per_image, dim=0)
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 9552ee1f..93d0aa30 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -591,6 +591,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
             e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
         ]
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py
index c190a458..6d7b760d 100644
--- a/vllm/model_executor/models/glm4v.py
+++ b/vllm/model_executor/models/glm4v.py
@@ -596,6 +596,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
 
         return self.transformer.vision(pixel_values)
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.transformer
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py
index 4b0513aa..ec02d1c8 100644
--- a/vllm/model_executor/models/idefics3.py
+++ b/vllm/model_executor/models/idefics3.py
@@ -710,6 +710,9 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
             e.flatten(0, 1) for e in image_features.split(num_patches.tolist())
         ]
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index c61254ac..0cda199a 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -56,6 +56,18 @@ class SupportsMultiModal(Protocol):
         """
         ...
 
+    def get_language_model(self) -> torch.nn.Module:
+        """
+        Returns the underlying language model used for text generation.
+
+        This is typically the `torch.nn.Module` instance responsible for
+        processing the merged multimodal embeddings and producing hidden states.
+
+        Returns:
+            torch.nn.Module: The core language model component.
+        """
+        ...
+
     # Only for models that support v0 chunked prefill
     # TODO(ywang96): Remove this overload once v0 is deprecated
     @overload
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index cf5608e3..7fd628fa 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -884,6 +884,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         else:
             self.visual_token_mask = None
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index b34ac38f..95165500 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -674,6 +674,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             image_embeds = torch.split(image_embeds, feature_sizes)
         return image_embeds
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index 4de13e54..9c4d0e1f 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -480,6 +480,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
             for i, patch_features_batch in enumerate(patch_embeddings)
         ]
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py
index 780af72d..6fc4c187 100644
--- a/vllm/model_executor/models/llava_next_video.py
+++ b/vllm/model_executor/models/llava_next_video.py
@@ -421,6 +421,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         return [e.flatten(0, 1) for e in embeds]
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         video_input = self._parse_and_validate_video_input(**kwargs)
diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index c7e13bb3..5fbd27b9 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -852,6 +852,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
             image_feature = image_feature.view(batch_frames, -1, dim)
         return image_feature
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index eb20a963..12b5364c 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -892,6 +892,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
 
         return multimodal_embeddings
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.llm
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py
index b6fbc6b1..67c0e2ec 100644
--- a/vllm/model_executor/models/mistral3.py
+++ b/vllm/model_executor/models/mistral3.py
@@ -514,6 +514,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
             image_embeds = (image_embeds, )
         return image_embeds
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 6a2e2084..a67339ca 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -1325,6 +1325,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
             cross_attention_states = cross_attention_states_flat
         return cross_attention_states
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_cross_attention_states(
         self,
         image_inputs: MllamaImagePixelInputs,
diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py
index d76d6377..0499fe09 100644
--- a/vllm/model_executor/models/mllama4.py
+++ b/vllm/model_executor/models/mllama4.py
@@ -742,6 +742,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
             for img in vision_embeddings_flat.split(patches_per_image, dim=0)
         ]
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(self,
                                   **kwargs) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index 6857bfa8..a7551e61 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -1488,6 +1488,9 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
             )
         ]
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
index 845f77ac..274163ac 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -323,6 +323,9 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         return self.multi_modal_projector(image_features)
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index d3b0688f..344f348c 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -674,6 +674,9 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
 
         return image_embeds
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py
index cb75ee1e..ec19797f 100644
--- a/vllm/model_executor/models/phi4mm.py
+++ b/vllm/model_executor/models/phi4mm.py
@@ -1802,3 +1802,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal,
             connector=["audio_projection_for_vision", "audio_projection"],
             tower_model=["vision_encoder", "embed_tokens_extend"],
         )
+
+    def get_language_model(self) -> torch.nn.Module:
+        return self.model
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index e07c6516..328d5271 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -396,6 +396,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
             image_embeds = torch.split(image_embeds, feature_sizes)
         return image_embeds
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 1e6ff1fe..84b7e59c 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -967,6 +967,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                                        **kwargs)
         return modalities
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py
index 54220037..9f2593fc 100644
--- a/vllm/model_executor/models/qwen2_audio.py
+++ b/vllm/model_executor/models/qwen2_audio.py
@@ -355,6 +355,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
         return torch.split(masked_audio_features,
                            audio_output_lengths.flatten().tolist())
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index a7800d41..f93654d0 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -1276,6 +1276,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         return modalities
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py
index a2ec9a9a..2e941f3b 100644
--- a/vllm/model_executor/models/qwen_vl.py
+++ b/vllm/model_executor/models/qwen_vl.py
@@ -740,6 +740,9 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
 
         return self.transformer.visual(image_input["data"])
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.transformer
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py
index e3deae82..a8460a2e 100644
--- a/vllm/model_executor/models/skyworkr1v.py
+++ b/vllm/model_executor/models/skyworkr1v.py
@@ -889,6 +889,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         else:
             self.visual_token_mask = None
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index 6e73a2ae..6e9d1526 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -563,6 +563,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
         ]
         return flattened_embeddings.split(embed_lens)
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index e83abbe8..7751f96d 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -692,6 +692,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
         )
         return decoder_outputs
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.model.decoder
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         # TODO: This method does not obey the interface for SupportsMultiModal.
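
For illustration, here is a minimal sketch of how calling code might rely on the new getter to reach the text backbone uniformly, regardless of whether a given model stores it as `language_model`, `model`, `llm`, `transformer`, or `model.decoder`. The `ToyMultiModalModel` class and the `count_language_model_params` helper below are hypothetical examples written for this note, not part of the patch or of vLLM itself.

```python
# Hypothetical usage sketch; only the get_language_model() pattern comes from the patch.
import torch
import torch.nn as nn


class ToyMultiModalModel(nn.Module):
    """Toy stand-in for a multimodal model that implements the getter."""

    def __init__(self) -> None:
        super().__init__()
        self.vision_tower = nn.Linear(16, 32)    # pretend vision encoder
        self.language_model = nn.Linear(32, 32)  # pretend text backbone

    def get_language_model(self) -> torch.nn.Module:
        # Mirrors the accessor added to each model in this patch.
        return self.language_model


def count_language_model_params(model: nn.Module) -> int:
    """Inspect the text backbone through the stable accessor."""
    language_model = model.get_language_model()
    return sum(p.numel() for p in language_model.parameters())


if __name__ == "__main__":
    model = ToyMultiModalModel()
    # 32 * 32 weights + 32 biases = 1056
    print(count_language_model_params(model))
```

Because every implementation above returns its own backbone attribute, callers no longer need per-model attribute lookups to reach the language model.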