[Model] VLM2Vec, the first multimodal embedding model in vLLM (#9303)

Cyrus Leung 2024-10-16 14:31:00 +08:00 committed by GitHub
parent 7e7eae338d
commit 7abba39ee6
16 changed files with 465 additions and 261 deletions

View File

@@ -3,7 +3,7 @@
 Supported Models
 ================
-vLLM supports a variety of generative Transformer models in `HuggingFace Transformers <https://huggingface.co/models>`_.
+vLLM supports a variety of generative Transformer models in `HuggingFace (HF) Transformers <https://huggingface.co/models>`_.
 The following is the list of model architectures that are currently supported by vLLM.
 Alongside each architecture, we include some popular models that use it.
@@ -19,7 +19,7 @@ Text Generation
   * - Architecture
     - Models
-    - Example HuggingFace Models
+    - Example HF Models
     - :ref:`LoRA <lora>`
     - :ref:`PP <distributed_serving>`
   * - :code:`AquilaForCausalLM`
@@ -280,7 +280,7 @@ Text Embedding
   * - Architecture
     - Models
-    - Example HuggingFace Models
+    - Example HF Models
     - :ref:`LoRA <lora>`
     - :ref:`PP <distributed_serving>`
   * - :code:`Gemma2Model`
@@ -303,7 +303,7 @@ Reward Modeling
   * - Architecture
     - Models
-    - Example HuggingFace Models
+    - Example HF Models
     - :ref:`LoRA <lora>`
     - :ref:`PP <distributed_serving>`
   * - :code:`Qwen2ForRewardModel`
@@ -316,7 +316,14 @@ Reward Modeling
   As an interim measure, these models are supported via Embeddings API. See `this RFC <https://github.com/vllm-project/vllm/issues/8967>`_ for upcoming changes.

 Multimodal Language Models
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The following modalities are supported depending on the model:
+
+- **T**\ ext
+- **I**\ mage
+- **V**\ ideo
+- **A**\ udio

 .. _supported_vlms:
@@ -324,78 +331,78 @@ Text Generation
 ---------------

 .. list-table::
-  :widths: 25 25 25 25 5 5
+  :widths: 25 25 15 25 5 5
   :header-rows: 1

   * - Architecture
     - Models
-    - Modalities
-    - Example HuggingFace Models
+    - Inputs
+    - Example HF Models
     - :ref:`LoRA <lora>`
     - :ref:`PP <distributed_serving>`
   * - :code:`Blip2ForConditionalGeneration`
     - BLIP-2
-    - Image\ :sup:`E`
+    - T + I\ :sup:`E`
     - :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc.
     -
     - ✅︎
   * - :code:`ChameleonForConditionalGeneration`
     - Chameleon
-    - Image
+    - T + I
     - :code:`facebook/chameleon-7b` etc.
     -
     - ✅︎
   * - :code:`FuyuForCausalLM`
     - Fuyu
-    - Image
+    - T + I
     - :code:`adept/fuyu-8b` etc.
     -
     - ✅︎
   * - :code:`ChatGLMModel`
     - GLM-4V
-    - Image
+    - T + I
     - :code:`THUDM/glm-4v-9b` etc.
     -
     - ✅︎
   * - :code:`InternVLChatModel`
     - InternVL2
-    - Image\ :sup:`E+`
+    - T + I\ :sup:`E+`
     - :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
     -
     - ✅︎
   * - :code:`LlavaForConditionalGeneration`
     - LLaVA-1.5
-    - Image\ :sup:`E+`
+    - T + I\ :sup:`E+`
     - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
     -
     - ✅︎
   * - :code:`LlavaNextForConditionalGeneration`
     - LLaVA-NeXT
-    - Image\ :sup:`E+`
+    - T + I\ :sup:`E+`
     - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
     -
     - ✅︎
   * - :code:`LlavaNextVideoForConditionalGeneration`
     - LLaVA-NeXT-Video
-    - Video
+    - T + V
     - :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc.
     -
     - ✅︎
   * - :code:`LlavaOnevisionForConditionalGeneration`
     - LLaVA-Onevision
-    - Image\ :sup:`+` / Video
+    - T + I\ :sup:`+` + V
     - :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc.
     -
     - ✅︎
   * - :code:`MiniCPMV`
     - MiniCPM-V
-    - Image\ :sup:`E+`
+    - T + I\ :sup:`E+`
     - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
     - ✅︎
     - ✅︎
   * - :code:`MllamaForConditionalGeneration`
     - Llama 3.2
-    - Image
+    - T + I
     - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc.
     -
     -
@@ -407,43 +414,43 @@ Text Generation
     - ✅︎
   * - :code:`NVLM_D_Model`
     - NVLM-D 1.0
-    - Image\ :sup:`E+`
+    - T + I\ :sup:`E+`
     - :code:`nvidia/NVLM-D-72B`, etc.
     -
     - ✅︎
   * - :code:`PaliGemmaForConditionalGeneration`
     - PaliGemma
-    - Image\ :sup:`E`
+    - T + I\ :sup:`E`
     - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc.
     -
     - ✅︎
   * - :code:`Phi3VForCausalLM`
     - Phi-3-Vision, Phi-3.5-Vision
-    - Image\ :sup:`E+`
+    - T + I\ :sup:`E+`
     - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc.
     -
     - ✅︎
   * - :code:`PixtralForConditionalGeneration`
     - Pixtral
-    - Image\ :sup:`+`
+    - T + I\ :sup:`+`
     - :code:`mistralai/Pixtral-12B-2409`
     -
     - ✅︎
   * - :code:`QWenLMHeadModel`
     - Qwen-VL
-    - Image\ :sup:`E+`
+    - T + I\ :sup:`E+`
     - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
     -
     - ✅︎
   * - :code:`Qwen2VLForConditionalGeneration`
     - Qwen2-VL
-    - Image\ :sup:`E+` / Video\ :sup:`+`
+    - T + I\ :sup:`E+` + V\ :sup:`+`
     - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
     -
     - ✅︎
   * - :code:`UltravoxModel`
     - Ultravox
-    - Audio\ :sup:`E+`
+    - T + A\ :sup:`E+`
     - :code:`fixie-ai/ultravox-v0_3`
     -
     - ✅︎
@@ -455,6 +462,26 @@ Text Generation
   For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
   For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
+
+Multimodal Embedding
+--------------------
+
+.. list-table::
+  :widths: 25 25 15 25 5 5
+  :header-rows: 1
+
+  * - Architecture
+    - Models
+    - Inputs
+    - Example HF Models
+    - :ref:`LoRA <lora>`
+    - :ref:`PP <distributed_serving>`
+  * - :code:`Phi3VForCausalLM`
+    - Phi-3-Vision-based
+    - T + I
+    - :code:`TIGER-Lab/VLM2Vec-Full`
+    - 🚧
+    - ✅︎

 ----

 If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.

View File

@@ -0,0 +1,21 @@
from vllm import LLM
from vllm.assets.image import ImageAsset

image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
prompt = "<|image_1|> Represent the given image with the following question: What is in the image"  # noqa: E501

# Create an LLM.
llm = LLM(
    model="TIGER-Lab/VLM2Vec-Full",
    trust_remote_code=True,
    max_model_len=4096,
    max_num_seqs=2,
    mm_processor_kwargs={"num_crops": 16},
)

# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = llm.encode({"prompt": prompt, "multi_modal_data": {"image": image}})

# Print the outputs.
for output in outputs:
    print(output.outputs.embedding)  # list of 3072 floats
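
As a rough follow-on sketch (not part of this commit's example script), the returned embedding can be scored against a second one with cosine similarity, e.g. for retrieval-style matching. It reuses llm, outputs, and ImageAsset from above; the second prompt/image pair is illustrative only:

# Illustrative only: embed a second image and compare it to the first
# with cosine similarity (how VLM2Vec embeddings are typically consumed).
import torch
import torch.nn.functional as F

image_2 = ImageAsset("stop_sign").pil_image.convert("RGB")
prompt_2 = "<|image_1|> Represent the given image with the following question: What is in the image"  # noqa: E501

outputs_2 = llm.encode({"prompt": prompt_2, "multi_modal_data": {"image": image_2}})

emb_1 = torch.tensor(outputs[0].outputs.embedding)
emb_2 = torch.tensor(outputs_2[0].outputs.embedding)
print(F.cosine_similarity(emb_1, emb_2, dim=0).item())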

View File

@@ -262,7 +262,7 @@ class HfRunner:
         dtype: str = "half",
         *,
         model_kwargs: Optional[Dict[str, Any]] = None,
-        is_embedding_model: bool = False,
+        is_sentence_transformer: bool = False,
         auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
         postprocess_inputs: Callable[[BatchEncoding],
                                      BatchEncoding] = identity,
@@ -271,7 +271,7 @@
         self.model_name = model_name

-        if is_embedding_model:
+        if is_sentence_transformer:
             # Lazy init required for AMD CI
             from sentence_transformers import SentenceTransformer
             self.model = self.wrap_device(
@@ -307,17 +307,23 @@ class HfRunner:
         self.postprocess_inputs = postprocess_inputs

-    def generate(
+    def get_inputs(
         self,
         prompts: List[str],
         images: Optional[PromptImageInput] = None,
-        videos: Optional[List[np.ndarray]] = None,
-        **kwargs: Any,
-    ) -> List[Tuple[List[List[int]], List[str]]]:
-        if images:
+        videos: Optional[PromptVideoInput] = None,
+        audios: Optional[PromptAudioInput] = None,
+    ) -> List[BatchEncoding]:
+        if images is not None:
             assert len(prompts) == len(images)
-        outputs: List[Tuple[List[List[int]], List[str]]] = []
+        if videos is not None:
+            assert len(prompts) == len(videos)
+        if audios is not None:
+            assert len(prompts) == len(audios)
+
+        all_inputs: List[BatchEncoding] = []
         for i, prompt in enumerate(prompts):
             processor_kwargs: Dict[str, Any] = {
                 "text": prompt,
@@ -327,10 +333,33 @@ class HfRunner:
                 processor_kwargs["images"] = images[i]
             if videos is not None and videos[i] is not None:
                 processor_kwargs["videos"] = videos[i]
+            if audios is not None and audios[i] is not None:
+                audio, sr = audios[i]
+                processor_kwargs["audio"] = audio
+                processor_kwargs["sampling_rate"] = sr

             inputs = self.processor(**processor_kwargs)
             inputs = self.postprocess_inputs(inputs)
+            all_inputs.append(inputs)
+
+        return all_inputs
+
+    def generate(
+        self,
+        prompts: List[str],
+        images: Optional[PromptImageInput] = None,
+        videos: Optional[List[np.ndarray]] = None,
+        audios: Optional[PromptAudioInput] = None,
+        **kwargs: Any,
+    ) -> List[Tuple[List[List[int]], List[str]]]:
+        all_inputs = self.get_inputs(prompts,
+                                     images=images,
+                                     videos=videos,
+                                     audios=audios)
+
+        outputs: List[Tuple[List[List[int]], List[str]]] = []
+        for inputs in all_inputs:
             output_ids = self.model.generate(
                 **self.wrap_device(inputs, device=self.model.device.type),
                 use_cache=True,
@@ -350,12 +379,16 @@
         prompts: List[str],
         max_tokens: int,
         images: Optional[PromptImageInput] = None,
+        videos: Optional[List[np.ndarray]] = None,
+        audios: Optional[PromptAudioInput] = None,
         **kwargs: Any,
     ) -> List[Tuple[List[int], str]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
                                 images=images,
+                                videos=videos,
+                                audios=audios,
                                 **kwargs)

         return [(output_ids[0], output_str[0])
@@ -388,22 +421,16 @@ class HfRunner:
         max_tokens: int,
         images: Optional[PromptImageInput] = None,
         videos: Optional[List[np.ndarray]] = None,
+        audios: Optional[PromptAudioInput] = None,
         **kwargs: Any,
     ) -> List[List[torch.Tensor]]:
+        all_inputs = self.get_inputs(prompts,
+                                     images=images,
+                                     videos=videos,
+                                     audios=audios)
+
         all_logprobs: List[List[torch.Tensor]] = []
-        for i, prompt in enumerate(prompts):
-            processor_kwargs: Dict[str, Any] = {
-                "text": prompt,
-                "return_tensors": "pt",
-            }
-            if images is not None and images[i] is not None:
-                processor_kwargs["images"] = images[i]
-            if videos is not None and videos[i] is not None:
-                processor_kwargs["videos"] = videos[i]
-            inputs = self.processor(**processor_kwargs)
-            inputs = self.postprocess_inputs(inputs)
+        for inputs in all_inputs:
             output = self.model.generate(
                 **self.wrap_device(inputs, device=self.model.device.type),
                 use_cache=True,
@@ -475,28 +502,16 @@ class HfRunner:
         videos: Optional[List[np.ndarray]] = None,
         **kwargs: Any,
     ) -> List[TokensTextLogprobs]:
+        all_inputs = self.get_inputs(prompts,
+                                     images=images,
+                                     videos=videos,
+                                     audios=audios)
+
         all_logprobs: List[List[Dict[int, float]]] = []
         all_output_ids: List[List[int]] = []
         all_output_strs: List[str] = []
-        for i, prompt in enumerate(prompts):
-            processor_kwargs: Dict[str, Any] = {
-                "text": prompt,
-                "return_tensors": "pt",
-            }
-            if images is not None and images[i] is not None:
-                processor_kwargs["images"] = images[i]
-            if audios is not None:
-                audio, sr = audios[i]
-                processor_kwargs["audio"] = audio
-                processor_kwargs["sampling_rate"] = sr
-            if videos is not None:
-                processor_kwargs["videos"] = videos[i]
-            inputs = self.processor(**processor_kwargs)
-            inputs = self.postprocess_inputs(inputs)
+        for inputs in all_inputs:
             output = self.model.generate(
                 **self.wrap_device(inputs, device=self.model.device.type),
                 use_cache=True,
@@ -632,20 +647,50 @@ class VllmRunner:
             **kwargs,
         )

-    def generate(
+    def get_inputs(
         self,
         prompts: List[str],
-        sampling_params: SamplingParams,
         images: Optional[PromptImageInput] = None,
-    ) -> List[Tuple[List[List[int]], List[str]]]:
+        videos: Optional[PromptVideoInput] = None,
+        audios: Optional[PromptAudioInput] = None,
+    ) -> List[TextPrompt]:
         if images is not None:
             assert len(prompts) == len(images)
+        if videos is not None:
+            assert len(prompts) == len(videos)
+        if audios is not None:
+            assert len(prompts) == len(audios)

         inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
         if images is not None:
             for i, image in enumerate(images):
                 inputs[i]["multi_modal_data"] = {"image": image}
+        if videos is not None:
+            for i, video in enumerate(videos):
+                inputs[i]["multi_modal_data"] = {"video": video}
+        if audios is not None:
+            for i, audio in enumerate(audios):
+                inputs[i]["multi_modal_data"] = {"audio": audio}
+
+        return inputs
+
+    def generate(
+        self,
+        prompts: List[str],
+        sampling_params: SamplingParams,
+        images: Optional[PromptImageInput] = None,
+        videos: Optional[PromptVideoInput] = None,
+        audios: Optional[PromptAudioInput] = None,
+    ) -> List[Tuple[List[List[int]], List[str]]]:
+        inputs = self.get_inputs(prompts,
+                                 images=images,
+                                 videos=videos,
+                                 audios=audios)

         req_outputs = self.model.generate(inputs,
                                           sampling_params=sampling_params)
@@ -687,24 +732,10 @@
         videos: Optional[PromptVideoInput] = None,
     ) -> Union[List[TokensTextLogprobs],
                List[TokensTextLogprobsPromptLogprobs]]:
-        if images is not None:
-            assert len(prompts) == len(images)
-        if videos is not None:
-            assert len(prompts) == len(videos)
-        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
-        if images is not None:
-            for i, image in enumerate(images):
-                inputs[i]["multi_modal_data"] = {"image": image}
-        if audios is not None:
-            for i, audio in enumerate(audios):
-                inputs[i]["multi_modal_data"] = {"audio": audio}
-        if videos is not None:
-            for i, video in enumerate(videos):
-                inputs[i]["multi_modal_data"] = {"video": video}
+        inputs = self.get_inputs(prompts,
+                                 images=images,
+                                 videos=videos,
+                                 audios=audios)

         req_outputs = self.model.generate(inputs,
                                           sampling_params=sampling_params)
@@ -741,9 +772,15 @@ class VllmRunner:
         prompts: List[str],
         max_tokens: int,
         images: Optional[PromptImageInput] = None,
+        videos: Optional[PromptVideoInput] = None,
+        audios: Optional[PromptAudioInput] = None,
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
-        outputs = self.generate(prompts, greedy_params, images=images)
+        outputs = self.generate(prompts,
+                                greedy_params,
+                                images=images,
+                                videos=videos,
+                                audios=audios)
         return [(output_ids[0], output_str[0])
                 for output_ids, output_str in outputs]

View File

@@ -1,10 +1,10 @@
-"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
+"""Compare the embedding outputs of HF and vLLM models.

 Run `pytest tests/models/embedding/language/test_embedding.py`.
 """
 import pytest
-import torch
-import torch.nn.functional as F
+
+from ..utils import check_embeddings_close

 MODELS = [
     "intfloat/e5-mistral-7b-instruct",
@@ -12,14 +12,6 @@ MODELS = [
 ]

-def compare_embeddings(embeddings1, embeddings2):
-    similarities = [
-        F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0)
-        for e1, e2 in zip(embeddings1, embeddings2)
-    ]
-    return similarities
-
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 def test_models(
@@ -37,15 +29,17 @@ def test_models(
     # So we need to strip the input texts to avoid test failing.
     example_prompts = [str(s).strip() for s in example_prompts]

-    with hf_runner(model, dtype=dtype, is_embedding_model=True) as hf_model:
+    with hf_runner(model, dtype=dtype,
+                   is_sentence_transformer=True) as hf_model:
         hf_outputs = hf_model.encode(example_prompts)

     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.encode(example_prompts)

-    similarities = compare_embeddings(hf_outputs, vllm_outputs)
-    all_similarities = torch.stack(similarities)
-    tolerance = 1e-2
-    assert torch.all((all_similarities <= 1.0 + tolerance)
-                     & (all_similarities >= 1.0 - tolerance)
-                     ), f"Not all values are within {tolerance} of 1.0"
+    check_embeddings_close(
+        embeddings_0_lst=hf_outputs,
+        embeddings_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+        tol=1e-2,
+    )

View File

@@ -0,0 +1,29 @@
from typing import List, Sequence

import torch
import torch.nn.functional as F


def check_embeddings_close(
    *,
    embeddings_0_lst: Sequence[List[float]],
    embeddings_1_lst: Sequence[List[float]],
    name_0: str,
    name_1: str,
    tol: float = 1e-3,
) -> None:
    assert len(embeddings_0_lst) == len(embeddings_1_lst)

    for prompt_idx, (embeddings_0, embeddings_1) in enumerate(
            zip(embeddings_0_lst, embeddings_1_lst)):
        assert len(embeddings_0) == len(embeddings_1)

        sim = F.cosine_similarity(torch.tensor(embeddings_0),
                                  torch.tensor(embeddings_1),
                                  dim=0)

        fail_msg = (f"Test{prompt_idx}:"
                    f"\n{name_0}:\t{embeddings_0!r}"
                    f"\n{name_1}:\t{embeddings_1!r}")

        assert sim >= 1 - tol, fail_msg
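
For illustration only (not part of the commit), a toy invocation of this helper with hand-written, nearly parallel vectors, which passes at tol=1e-2:

# Toy check: cosine similarity of these two vectors is ~0.9999995 >= 1 - 1e-2.
check_embeddings_close(
    embeddings_0_lst=[[1.0, 0.0, 0.0]],
    embeddings_1_lst=[[0.999, 0.001, 0.0]],
    name_0="ref",
    name_1="test",
    tol=1e-2,
)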

View File

@@ -0,0 +1,62 @@
import pytest
import torch.nn.functional as F

from ....conftest import IMAGE_ASSETS
from ..utils import check_embeddings_close

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "<|image_1|> Select the portion of the image that isolates the object of the given label: The label of the object is stop sign",  # noqa: E501
    "cherry_blossom":
    "<|image_1|> Represent the given image with the following question: What is in the image",  # noqa: E501
})

MODELS = ["TIGER-Lab/VLM2Vec-Full"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
    with vllm_runner(model,
                     max_model_len=4096,
                     max_num_seqs=2,
                     dtype=dtype,
                     enforce_eager=True) as vllm_model:
        vllm_outputs = vllm_model.encode(example_prompts)

    with hf_runner(model, dtype=dtype) as hf_model:
        all_inputs = hf_model.get_inputs(example_prompts)

        all_outputs = []
        for inputs in all_inputs:
            # Based on: https://github.com/TIGER-AI-Lab/VLM2Vec/blob/db3b951bccabba220c1f53ab46a734e50dd2fc08/src/model.py
            outputs = hf_model.model(
                **hf_model.wrap_device(inputs,
                                       device=hf_model.model.device.type),
                return_dict=True,
                output_hidden_states=True,
            )
            last_hidden_state = outputs.hidden_states[-1][0]
            reps = last_hidden_state[inputs.attention_mask[0].sum() - 1]
            pooled_output = F.normalize(reps, p=2, dim=-1)

            all_outputs.append(pooled_output.tolist())

        hf_outputs = all_outputs

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )
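
Note that the HF reference above pools by taking the final layer's hidden state at the last attended token and L2-normalizing it; this mirrors the Pooler(pooling_type=PoolingType.LAST, normalize=True) registered by the vLLM model classes in this commit, which is why the two sides are expected to agree within the default tolerance.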

View File

@@ -3,7 +3,7 @@ from typing import List, Optional, Union
 import torch

 from vllm.attention import AttentionMetadata
-from vllm.model_executor.models.gemma2_embedding import Gemma2EmbeddingModel
+from vllm.model_executor.models.gemma2 import Gemma2EmbeddingModel
 from vllm.sequence import IntermediateTensors

View File

@@ -237,7 +237,16 @@ class ModelConfig:
     def _verify_embedding_mode(self) -> None:
         architectures = getattr(self.hf_config, "architectures", [])
-        self.embedding_mode = ModelRegistry.is_embedding_model(architectures)
+
+        # TODO: Allow the same model architecture to be specified as either
+        # generation or embedding model
+        if "Phi3VForCausalLM" in architectures:
+            # Match both remote and local names
+            embedding_mode = "/VLM2Vec" in self.model
+        else:
+            embedding_mode = ModelRegistry.is_embedding_model(architectures)
+
+        self.embedding_mode = embedding_mode

     def _parse_quant_hf_config(self):
         quant_cfg = getattr(self.hf_config, "quantization_config", None)

View File

@@ -31,14 +31,16 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, PoolerOutput

 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
@@ -461,3 +463,50 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 if self.config.tie_word_embeddings else None),
         )
         loader.load_weights(weights)
+
+
+class Gemma2EmbeddingModel(nn.Module, SupportsPP):
+    """
+    A model that uses Gemma2 with additional embedding functionalities.
+
+    This class encapsulates the Gemma2Model and provides an interface for
+    embedding operations and customized pooling functions.
+
+    Attributes:
+        model: An instance of Gemma2Model used for forward operations.
+        _pooler: An instance of Pooler used for pooling operations.
+    """
+
+    def __init__(
+        self,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+
+        self.model = Gemma2Model(**kwargs)
+        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        return self.model(input_ids, positions, kv_caches, attn_metadata,
+                          intermediate_tensors, inputs_embeds)
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self.model.load_weights(weights)

View File

@@ -1,57 +0,0 @@
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn

from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from .gemma2 import Gemma2Model
from .interfaces import SupportsPP


class Gemma2EmbeddingModel(nn.Module, SupportsPP):
    """A model that uses Gemma2 with additional embedding functionalities.

    This class encapsulates the Gemma2Model and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of Gemma2Model used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__()

        self.model = Gemma2Model(**kwargs)
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors, inputs_embeds)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        self.model.load_weights(weights)

View File

@@ -38,6 +38,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     get_compressed_tensors_cache_scale)
@@ -47,8 +48,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.utils import is_hip

 from .interfaces import SupportsLoRA, SupportsPP
@@ -615,3 +617,52 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 name = name.replace(item, mapping[item])
         return name, loaded_weight
+
+
+class LlamaEmbeddingModel(nn.Module, SupportsPP):
+    """
+    A model that uses Llama with additional embedding functionalities.
+
+    This class encapsulates the LlamaModel and provides an interface for
+    embedding operations and customized pooling functions.
+
+    Attributes:
+        model: An instance of LlamaModel used for forward operations.
+        _pooler: An instance of Pooler used for pooling operations.
+    """
+
+    def __init__(
+        self,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+
+        self.model = LlamaModel(**kwargs)
+        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        return self.model(input_ids, positions, kv_caches, attn_metadata,
+                          intermediate_tensors, inputs_embeds)
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self.model.load_weights(weights)
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)

View File

@@ -1,59 +0,0 @@
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn

from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from .interfaces import SupportsPP
from .llama import LlamaModel


class LlamaEmbeddingModel(nn.Module, SupportsPP):
    """A model that uses Llama with additional embedding functionalities.

    This class encapsulates the LlamaModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of LlamaModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__()

        self.model = LlamaModel(**kwargs)
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors, inputs_embeds)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        self.model.load_weights(weights)

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        self.model.load_kv_cache_scales(quantization_param_path)

View File

@@ -29,14 +29,18 @@ from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.utils import is_list_of

 from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
@@ -289,10 +293,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
                                    dim=2).reshape(num_images, -1, hid_dim)
         return image_features_hd_newline

-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        loader = AutoWeightsLoader(self)
-        loader.load_weights(weights)

 # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
 def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
@@ -385,23 +385,28 @@ def dummy_data_for_phi3v(ctx: InputContext,
     return seq_data, mm_data

-# Reserve this function to also handle placeholders for additional images
-# [ref: PR #5820]
 @lru_cache
-def _get_image_placeholder_token_ids(model_config: ModelConfig,
-                                     idx: int) -> List[int]:
+def _get_image_placeholder_token_id_candidates(
+    model_config: ModelConfig,
+    idx: int,
+) -> List[List[int]]:
     assert idx > 0

     tokenizer = cached_get_tokenizer(model_config.tokenizer)

+    # This is used when the image token is at the start of the string
+    start_candidate = tokenizer.encode(f"<|image_{idx}|>",
+                                       add_special_tokens=False)
+
+    # This is used when the image token is in the middle of the string
     # We need to get the token for "<", not "▁<"
     # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
     a_token_id, = tokenizer.encode("a", add_special_tokens=False)
-    a_token_id_, *image_placeholder_token_ids = tokenizer.encode(
-        f"a<|image_{idx}|>", add_special_tokens=False)
+    a_token_id_, *middle_candidate = tokenizer.encode(f"a<|image_{idx}|>",
+                                                      add_special_tokens=False)
     assert a_token_id == a_token_id_

-    return image_placeholder_token_ids
+    return [start_candidate, middle_candidate]

 def input_processor_for_phi3v(ctx: InputContext,
@@ -461,16 +466,20 @@ def input_processor_for_phi3v(ctx: InputContext,
     prompt_token_ids = llm_inputs["prompt_token_ids"].copy()

-    # masked place_holder with image token id
+    print("prompt_token_ids (old)", prompt_token_ids)
+
+    # masked placeholder with image token id
     for idx in image_idx:
-        image_token_ids = _get_image_placeholder_token_ids(model_config,
-                                                           idx=idx)
-        for i in range(len(prompt_token_ids) - len(image_token_ids) + 1):
-            if prompt_token_ids[i:i + len(image_token_ids)] == image_token_ids:
-                prompt_token_ids[i:i + len(image_token_ids)] = [
-                    _IMAGE_TOKEN_ID
-                ] * len(image_token_ids)
-                break
+        candidates = _get_image_placeholder_token_id_candidates(model_config,
+                                                                idx=idx)
+
+        for candidate in candidates:
+            for i in range(len(prompt_token_ids) - len(candidate) + 1):
+                if prompt_token_ids[i:i + len(candidate)] == candidate:
+                    prompt_token_ids[i:i +
+                                     len(candidate)] = ([_IMAGE_TOKEN_ID] *
+                                                        len(candidate))
+                    break

     # merge consecutive tag ids
     merged_token_ids: List[int] = []
@@ -520,12 +529,23 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         self.multimodal_config = multimodal_config
         self.image_token_id = _IMAGE_TOKEN_ID

-        # TODO: Optionally initializes this for supporting embeddings.
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            quant_config=quant_config,
+        )
+
+        # TODO: Optionally initializes this for supporting input embeddings.
         self.vision_embed_tokens = Phi3HDImageEmbedding(config)
         self.language_model = LlamaForCausalLM(config, cache_config,
                                                quant_config)

+        # The same model class supports both language generation and embedding
+        # because the architecture name is the same
+        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
@@ -649,8 +669,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         if image_input is not None:
             vision_embeddings = self._process_image_input(image_input)
-            inputs_embeds = self.language_model.model.get_input_embeddings(
-                input_ids)
+            inputs_embeds = self.embed_tokens(input_ids)
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
                 self.image_token_id)
@@ -682,13 +701,27 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)

+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         hf_to_vllm_mapper = WeightsMapper(
             orig_to_new_prefix={
+                "model.vision_embed_tokens.wte": "embed_tokens",
                 "model.vision_embed_tokens.": "vision_embed_tokens.",
                 "lm_head.": "language_model.lm_head.",
                 "model.": "language_model.model.",
             })

         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights, mapper=hf_to_vllm_mapper)
+        autoloaded_weights = loader.load_weights(weights,
+                                                 mapper=hf_to_vllm_mapper)
+
+        # The HF config doesn't specify whether these are tied,
+        # so we detect it this way
+        if "embed_tokens" not in autoloaded_weights:
+            self.embed_tokens = self.language_model.model.embed_tokens
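
The tied-embedding handling above relies on AutoWeightsLoader.load_weights now returning the names of the weights it actually loaded (see the utils.py change further down). As a generic, self-contained sketch of that pattern in plain PyTorch (illustrative only, not vLLM code; TinyModel and its names are hypothetical):

# Load a checkpoint, record which parameters it provided, and fall back to
# sharing the LM's embedding matrix when no separate one was provided.
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self, vocab: int = 8, dim: int = 4) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)
        self.lm_embed = nn.Embedding(vocab, dim)

    def load_weights(self, weights: dict) -> list:
        params = dict(self.named_parameters())
        loaded = []
        for name, tensor in weights.items():
            params[name].data.copy_(tensor)
            loaded.append(name)
        # Checkpoint shipped no separate input embedding -> tie to the LM's.
        if "embed_tokens.weight" not in loaded:
            self.embed_tokens = self.lm_embed
        return loaded

model = TinyModel()
model.load_weights({"lm_embed.weight": torch.randn(8, 4)})
assert model.embed_tokens is model.lm_embed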

View File

@@ -86,9 +86,12 @@ _TEXT_GENERATION_MODELS = {
 }

 _EMBEDDING_MODELS = {
-    "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
-    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
-    "Gemma2Model": ("gemma2_embedding", "Gemma2EmbeddingModel"),
+    # [Text-only]
+    "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
+    "MistralModel": ("llama", "LlamaEmbeddingModel"),
+    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
+    # [Multimodal]
+    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
 }

 _MULTIMODAL_MODELS = {

View File

@@ -124,7 +124,7 @@ class AutoWeightsLoader:
         base_prefix: str,
         param: nn.Parameter,
         weights: Iterable[Tuple[str, torch.Tensor]],
-    ) -> None:
+    ) -> Iterable[str]:
         for weight_name, weight_data in weights:
             weight_qualname = self._get_qualname(base_prefix, weight_name)
@@ -143,12 +143,14 @@
                                     default_weight_loader)
             weight_loader(param, weight_data)

+            yield weight_qualname
+
     def _load_module(
         self,
         base_prefix: str,
         module: nn.Module,
         weights: Iterable[Tuple[str, torch.Tensor]],
-    ) -> None:
+    ) -> Iterable[str]:
         if isinstance(module, PPMissingLayer):
             return
@@ -170,14 +172,16 @@
                 continue

             if child_prefix in child_modules:
-                self._load_module(prefix, child_modules[child_prefix],
-                                  child_weights)
+                yield from self._load_module(prefix,
+                                             child_modules[child_prefix],
+                                             child_weights)
             elif child_prefix in child_params:
-                self._load_param(prefix, child_params[child_prefix],
-                                 child_weights)
+                yield from self._load_param(prefix, child_params[child_prefix],
+                                            child_weights)
             else:
                 if not self._can_ignore_unexpected(prefix):
-                    msg = f"There is no module or parameter named '{prefix}'"
+                    msg = (f"There is no module or parameter named '{prefix}' "
+                           f"in {type(self.module).__name__}")
                     raise ValueError(msg)

     def load_weights(
@@ -185,11 +189,12 @@
         weights: Iterable[Tuple[str, torch.Tensor]],
         *,
         mapper: Optional[WeightsMapper] = None,
-    ) -> None:
+    ) -> List[str]:
         if mapper is not None:
             weights = mapper.apply(weights)

-        self._load_module("", self.module, weights)
+        autoloaded_weights = list(self._load_module("", self.module, weights))
+        return autoloaded_weights

 def init_vllm_registered_model(