[Model] Add SupportsMultiModal.get_language_model interface (#16007)
Signed-off-by: NickLucche <nlucches@redhat.com>
parent 04149cce27
commit d55244df31
@@ -79,6 +79,17 @@ Further update the model as follows:
          return inputs_embeds
  ```

- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model` getter to provide stable access to the underlying language model.

  ```python
  class YourModelForImage2Seq(nn.Module):
      ...

      def get_language_model(self) -> torch.nn.Module:
          # Change `language_model` according to your implementation.
          return self.language_model
  ```

- Once the above steps are done, update the model class with the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.

  ```diff
@@ -605,6 +605,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):

        return self.multi_modal_projector(image_outputs, image_attn_mask)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -424,6 +424,9 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
            num_patches=num_patches,
        )

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -627,6 +627,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,

        return self.language_projection(query_output)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -988,6 +988,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
            data=self._validate_pixel_values(pixel_values),
        )

    def get_language_model(self) -> torch.nn.Module:
        return self.model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -604,6 +604,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
        return self._pixel_values_to_embedding(
            pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -1050,6 +1050,9 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
        pixel_values = image_input["data"]
        return self._encode_image(pixel_values)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -341,6 +341,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

        return vision_embeddings_flat.split(patches_per_image, dim=0)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -591,6 +591,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
            e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
        ]

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -596,6 +596,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,

        return self.transformer.vision(pixel_values)

    def get_language_model(self) -> torch.nn.Module:
        return self.transformer

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -710,6 +710,9 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
            e.flatten(0, 1) for e in image_features.split(num_patches.tolist())
        ]

    def get_language_model(self) -> torch.nn.Module:
        return self.model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -56,6 +56,18 @@ class SupportsMultiModal(Protocol):
        """
        ...

    def get_language_model(self) -> torch.nn.Module:
        """
        Returns the underlying language model used for text generation.

        This is typically the `torch.nn.Module` instance responsible for
        processing the merged multimodal embeddings and producing hidden states

        Returns:
            torch.nn.Module: The core language model component.
        """
        ...

    # Only for models that support v0 chunked prefill
    # TODO(ywang96): Remove this overload once v0 is deprecated
    @overload
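For context, here is a minimal, self-contained sketch of how the new interface method is meant to be satisfied and consumed. The names `HasLanguageModel` and `ToyMultiModalModel` are invented for illustration only and are not part of this commit; in vLLM the real protocol is the `SupportsMultiModal` class shown in the hunk above.

```python
from typing import Protocol, runtime_checkable

import torch
import torch.nn as nn


@runtime_checkable
class HasLanguageModel(Protocol):
    """Structural stand-in mirroring the new SupportsMultiModal method."""

    def get_language_model(self) -> nn.Module:
        ...


class ToyMultiModalModel(nn.Module):
    """Toy model: a stand-in vision tower plus a stand-in text backbone."""

    def __init__(self) -> None:
        super().__init__()
        self.vision_tower = nn.Linear(16, 32)
        self.language_model = nn.Linear(32, 32)

    def get_language_model(self) -> nn.Module:
        # The attribute name is an implementation detail; the getter is the
        # stable access point callers should rely on.
        return self.language_model


model = ToyMultiModalModel()
assert isinstance(model, HasLanguageModel)  # structural check only
hidden_states = model.get_language_model()(torch.randn(1, 32))
print(hidden_states.shape)  # torch.Size([1, 32])
```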
@@ -884,6 +884,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
        else:
            self.visual_token_mask = None

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -674,6 +674,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -480,6 +480,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -421,6 +421,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,

        return [e.flatten(0, 1) for e in embeds]

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        video_input = self._parse_and_validate_video_input(**kwargs)
@@ -852,6 +852,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
        image_feature = image_feature.view(batch_frames, -1, dim)
        return image_feature

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
@@ -892,6 +892,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):

        return multimodal_embeddings

    def get_language_model(self) -> torch.nn.Module:
        return self.llm

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
@@ -514,6 +514,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
        image_embeds = (image_embeds, )
        return image_embeds

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -1325,6 +1325,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
        cross_attention_states = cross_attention_states_flat
        return cross_attention_states

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_cross_attention_states(
        self,
        image_inputs: MllamaImagePixelInputs,
@@ -742,6 +742,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
            for img in vision_embeddings_flat.split(patches_per_image, dim=0)
        ]

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -1488,6 +1488,9 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
            )
        ]

    def get_language_model(self) -> torch.nn.Module:
        return self.model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -323,6 +323,9 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,

        return self.multi_modal_projector(image_features)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -674,6 +674,9 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,

        return image_embeds

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -1802,3 +1802,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal,
            connector=["audio_projection_for_vision", "audio_projection"],
            tower_model=["vision_encoder", "embed_tokens_extend"],
        )

    def get_language_model(self) -> torch.nn.Module:
        return self.model
@@ -396,6 +396,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -967,6 +967,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                    **kwargs)
        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
@@ -355,6 +355,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
        return torch.split(masked_audio_features,
                           audio_output_lengths.flatten().tolist())

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        audio_input = self._parse_and_validate_audio_input(**kwargs)
@@ -1276,6 +1276,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
@@ -740,6 +740,9 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,

        return self.transformer.visual(image_input["data"])

    def get_language_model(self) -> torch.nn.Module:
        return self.transformer

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -889,6 +889,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
        else:
            self.visual_token_mask = None

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
@@ -563,6 +563,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
        ]
        return flattened_embeddings.split(embed_lens)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        audio_input = self._parse_and_validate_audio_input(**kwargs)
@@ -692,6 +692,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
        )
        return decoder_outputs

    def get_language_model(self) -> torch.nn.Module:
        return self.model.decoder

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        # TODO: This method does not obey the interface for SupportsMultiModal.
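Taken together, the per-model hunks above map the getter onto whatever attribute actually holds the text backbone (`language_model`, `model`, `transformer`, `llm`, or `model.decoder` in the Whisper case). As a rough illustration only, the helper below is hypothetical and not part of this commit; it sketches how calling code can now reach that backbone without knowing the attribute name:

```python
import torch.nn as nn


def count_text_backbone_params(model: nn.Module) -> int:
    """Hypothetical helper: size of the text backbone via the stable getter.

    Works identically whether the model stores its language model under
    `language_model`, `model`, `transformer`, `llm`, or `model.decoder`,
    because the attribute name is hidden behind get_language_model().
    """
    language_model = model.get_language_model()
    return sum(p.numel() for p in language_model.parameters())
```

With the interface in place, the same helper works unchanged for any of the architectures touched by this commit.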