[Model] Add SupportsMultiModal.get_language_model interface (#16007)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi 2025-04-09 13:12:54 +02:00 committed by GitHub
parent 04149cce27
commit d55244df31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
33 changed files with 116 additions and 0 deletions

View File

@ -79,6 +79,17 @@ Further update the model as follows:
return inputs_embeds
```
- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model` getter to provide stable access to the underlying language model.
```python
class YourModelForImage2Seq(nn.Module):
...
def get_language_model(self) -> torch.nn.Module:
# Change `language_model` according to your implementation.
return self.language_model
```
- Once the above steps are done, update the model class with the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
```diff

View File

@ -605,6 +605,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
return self.multi_modal_projector(image_outputs, image_attn_mask)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -424,6 +424,9 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
num_patches=num_patches,
)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -627,6 +627,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return self.language_projection(query_output)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -988,6 +988,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
data=self._validate_pixel_values(pixel_values),
)
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -604,6 +604,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return self._pixel_values_to_embedding(
pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -1050,6 +1050,9 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values = image_input["data"]
return self._encode_image(pixel_values)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -341,6 +341,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return vision_embeddings_flat.split(patches_per_image, dim=0)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -591,6 +591,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -596,6 +596,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
return self.transformer.vision(pixel_values)
def get_language_model(self) -> torch.nn.Module:
return self.transformer
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -710,6 +710,9 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
e.flatten(0, 1) for e in image_features.split(num_patches.tolist())
]
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -56,6 +56,18 @@ class SupportsMultiModal(Protocol):
"""
...
def get_language_model(self) -> torch.nn.Module:
"""
Returns the underlying language model used for text generation.
This is typically the `torch.nn.Module` instance responsible for
processing the merged multimodal embeddings and producing hidden states
Returns:
torch.nn.Module: The core language model component.
"""
...
# Only for models that support v0 chunked prefill
# TODO(ywang96): Remove this overload once v0 is deprecated
@overload

View File

@ -884,6 +884,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
else:
self.visual_token_mask = None
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -674,6 +674,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -480,6 +480,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
for i, patch_features_batch in enumerate(patch_embeddings)
]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -421,6 +421,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
return [e.flatten(0, 1) for e in embeds]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
video_input = self._parse_and_validate_video_input(**kwargs)

View File

@ -852,6 +852,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
image_feature = image_feature.view(batch_frames, -1, dim)
return image_feature
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)

View File

@ -892,6 +892,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
return multimodal_embeddings
def get_language_model(self) -> torch.nn.Module:
return self.llm
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)

View File

@ -514,6 +514,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds = (image_embeds, )
return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -1325,6 +1325,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
cross_attention_states = cross_attention_states_flat
return cross_attention_states
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_cross_attention_states(
self,
image_inputs: MllamaImagePixelInputs,

View File

@ -742,6 +742,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
for img in vision_embeddings_flat.split(patches_per_image, dim=0)
]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self,
**kwargs) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -1488,6 +1488,9 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
)
]
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -323,6 +323,9 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.multi_modal_projector(image_features)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -674,6 +674,9 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -1802,3 +1802,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal,
connector=["audio_projection_for_vision", "audio_projection"],
tower_model=["vision_encoder", "embed_tokens_extend"],
)
def get_language_model(self) -> torch.nn.Module:
return self.model

View File

@ -396,6 +396,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -967,6 +967,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
**kwargs)
return modalities
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:

View File

@ -355,6 +355,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
return torch.split(masked_audio_features,
audio_output_lengths.flatten().tolist())
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
audio_input = self._parse_and_validate_audio_input(**kwargs)

View File

@ -1276,6 +1276,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return modalities
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:

View File

@ -740,6 +740,9 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
return self.transformer.visual(image_input["data"])
def get_language_model(self) -> torch.nn.Module:
return self.transformer
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -889,6 +889,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
else:
self.visual_token_mask = None
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@ -563,6 +563,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
]
return flattened_embeddings.split(embed_lens)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
audio_input = self._parse_and_validate_audio_input(**kwargs)

View File

@ -692,6 +692,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
)
return decoder_outputs
def get_language_model(self) -> torch.nn.Module:
return self.model.decoder
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
# TODO: This method does not obey the interface for SupportsMultiModal.