[Model] VLM2Vec, the first multimodal embedding model in vLLM (#9303)
This commit is contained in: parent 7e7eae338d, commit 7abba39ee6
@@ -3,7 +3,7 @@
Supported Models
================

vLLM supports a variety of generative Transformer models in `HuggingFace Transformers <https://huggingface.co/models>`_.
vLLM supports a variety of generative Transformer models in `HuggingFace (HF) Transformers <https://huggingface.co/models>`_.
The following is the list of model architectures that are currently supported by vLLM.
Alongside each architecture, we include some popular models that use it.

@@ -19,7 +19,7 @@ Text Generation

* - Architecture
- Models
- Example HuggingFace Models
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`AquilaForCausalLM`
@@ -280,7 +280,7 @@ Text Embedding

* - Architecture
- Models
- Example HuggingFace Models
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`Gemma2Model`
@@ -303,7 +303,7 @@ Reward Modeling

* - Architecture
- Models
- Example HuggingFace Models
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`Qwen2ForRewardModel`
@@ -316,7 +316,14 @@ Reward Modeling
As an interim measure, these models are supported via Embeddings API. See `this RFC <https://github.com/vllm-project/vllm/issues/8967>`_ for upcoming changes.

Multimodal Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^

The following modalities are supported depending on the model:

- **T**\ ext
- **I**\ mage
- **V**\ ideo
- **A**\ udio

.. _supported_vlms:

@@ -324,78 +331,78 @@ Text Generation
---------------

.. list-table::
:widths: 25 25 25 25 5 5
:widths: 25 25 15 25 5 5
:header-rows: 1

* - Architecture
- Models
- Modalities
- Example HuggingFace Models
- Inputs
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`Blip2ForConditionalGeneration`
- BLIP-2
- Image\ :sup:`E`
- T + I\ :sup:`E`
- :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc.
-
- ✅︎
* - :code:`ChameleonForConditionalGeneration`
- Chameleon
- Image
- T + I
- :code:`facebook/chameleon-7b` etc.
-
- ✅︎
* - :code:`FuyuForCausalLM`
- Fuyu
- Image
- T + I
- :code:`adept/fuyu-8b` etc.
-
- ✅︎
* - :code:`ChatGLMModel`
- GLM-4V
- Image
- T + I
- :code:`THUDM/glm-4v-9b` etc.
-
- ✅︎
* - :code:`InternVLChatModel`
- InternVL2
- Image\ :sup:`E+`
- T + I\ :sup:`E+`
- :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
-
- ✅︎
* - :code:`LlavaForConditionalGeneration`
- LLaVA-1.5
- Image\ :sup:`E+`
- T + I\ :sup:`E+`
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
-
- ✅︎
* - :code:`LlavaNextForConditionalGeneration`
- LLaVA-NeXT
- Image\ :sup:`E+`
- T + I\ :sup:`E+`
- :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
-
- ✅︎
* - :code:`LlavaNextVideoForConditionalGeneration`
- LLaVA-NeXT-Video
- Video
- T + V
- :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc.
-
- ✅︎
* - :code:`LlavaOnevisionForConditionalGeneration`
- LLaVA-Onevision
- Image\ :sup:`+` / Video
- T + I\ :sup:`+` + V
- :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc.
-
- ✅︎
* - :code:`MiniCPMV`
- MiniCPM-V
- Image\ :sup:`E+`
- T + I\ :sup:`E+`
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
- ✅︎
- ✅︎
* - :code:`MllamaForConditionalGeneration`
- Llama 3.2
- Image
- T + I
- :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc.
-
-
@@ -407,43 +414,43 @@ Text Generation
- ✅︎
* - :code:`NVLM_D_Model`
- NVLM-D 1.0
- Image\ :sup:`E+`
- T + I\ :sup:`E+`
- :code:`nvidia/NVLM-D-72B`, etc.
-
- ✅︎
* - :code:`PaliGemmaForConditionalGeneration`
- PaliGemma
- Image\ :sup:`E`
- T + I\ :sup:`E`
- :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc.
-
- ✅︎
* - :code:`Phi3VForCausalLM`
- Phi-3-Vision, Phi-3.5-Vision
- Image\ :sup:`E+`
- T + I\ :sup:`E+`
- :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc.
-
- ✅︎
* - :code:`PixtralForConditionalGeneration`
- Pixtral
- Image\ :sup:`+`
- T + I\ :sup:`+`
- :code:`mistralai/Pixtral-12B-2409`
-
- ✅︎
* - :code:`QWenLMHeadModel`
- Qwen-VL
- Image\ :sup:`E+`
- T + I\ :sup:`E+`
- :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
-
- ✅︎
* - :code:`Qwen2VLForConditionalGeneration`
- Qwen2-VL
- Image\ :sup:`E+` / Video\ :sup:`+`
- T + I\ :sup:`E+` + V\ :sup:`+`
- :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
-
- ✅︎
* - :code:`UltravoxModel`
- Ultravox
- Audio\ :sup:`E+`
- T + A\ :sup:`E+`
- :code:`fixie-ai/ultravox-v0_3`
-
- ✅︎
@@ -455,6 +462,26 @@ Text Generation
For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630

Multimodal Embedding
--------------------

.. list-table::
:widths: 25 25 15 25 5 5
:header-rows: 1

* - Architecture
- Models
- Inputs
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`Phi3VForCausalLM`
- Phi-3-Vision-based
- T + I
- :code:`TIGER-Lab/VLM2Vec-Full`
- 🚧
- ✅︎

----

If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
examples/offline_inference_vision_language_embedding.py (new file, 21 lines)
@@ -0,0 +1,21 @@
from vllm import LLM
from vllm.assets.image import ImageAsset

image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
prompt = "<|image_1|> Represent the given image with the following question: What is in the image"  # noqa: E501

# Create an LLM.
llm = LLM(
    model="TIGER-Lab/VLM2Vec-Full",
    trust_remote_code=True,
    max_model_len=4096,
    max_num_seqs=2,
    mm_processor_kwargs={"num_crops": 16},
)

# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = llm.encode({"prompt": prompt, "multi_modal_data": {"image": image}})

# Print the outputs.
for output in outputs:
    print(output.outputs.embedding)  # list of 3072 floats
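As a follow-up to the example above, here is a hedged sketch (not part of this commit) of how such embeddings are typically consumed: encode a second, text-only prompt with the same `llm` object and score it against the image embedding with cosine similarity. The query wording is only an assumption about the retrieval prompt format.

import torch
import torch.nn.functional as F

# Hypothetical text-only query; reuses `llm` and `outputs` from the example above.
query = "Represent the given question for retrieval: What is in the image"  # assumed prompt format
query_outputs = llm.encode({"prompt": query})

image_emb = torch.tensor(outputs[0].outputs.embedding)
query_emb = torch.tensor(query_outputs[0].outputs.embedding)

# Cosine similarity between the image embedding and the query embedding.
print(F.cosine_similarity(image_emb, query_emb, dim=0).item())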
@ -262,7 +262,7 @@ class HfRunner:
|
||||
dtype: str = "half",
|
||||
*,
|
||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
is_embedding_model: bool = False,
|
||||
is_sentence_transformer: bool = False,
|
||||
auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
|
||||
postprocess_inputs: Callable[[BatchEncoding],
|
||||
BatchEncoding] = identity,
|
||||
@ -271,7 +271,7 @@ class HfRunner:
|
||||
|
||||
self.model_name = model_name
|
||||
|
||||
if is_embedding_model:
|
||||
if is_sentence_transformer:
|
||||
# Lazy init required for AMD CI
|
||||
from sentence_transformers import SentenceTransformer
|
||||
self.model = self.wrap_device(
|
||||
@ -307,17 +307,23 @@ class HfRunner:
|
||||
|
||||
self.postprocess_inputs = postprocess_inputs
|
||||
|
||||
def generate(
|
||||
def get_inputs(
|
||||
self,
|
||||
prompts: List[str],
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[List[np.ndarray]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[List[List[int]], List[str]]]:
|
||||
if images:
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
) -> List[BatchEncoding]:
|
||||
if images is not None:
|
||||
assert len(prompts) == len(images)
|
||||
|
||||
outputs: List[Tuple[List[List[int]], List[str]]] = []
|
||||
if videos is not None:
|
||||
assert len(prompts) == len(videos)
|
||||
|
||||
if audios is not None:
|
||||
assert len(prompts) == len(audios)
|
||||
|
||||
all_inputs: List[BatchEncoding] = []
|
||||
for i, prompt in enumerate(prompts):
|
||||
processor_kwargs: Dict[str, Any] = {
|
||||
"text": prompt,
|
||||
@ -327,10 +333,33 @@ class HfRunner:
|
||||
processor_kwargs["images"] = images[i]
|
||||
if videos is not None and videos[i] is not None:
|
||||
processor_kwargs["videos"] = videos[i]
|
||||
if audios is not None and audios[i] is not None:
|
||||
audio, sr = audios[i]
|
||||
processor_kwargs["audio"] = audio
|
||||
processor_kwargs["sampling_rate"] = sr
|
||||
|
||||
inputs = self.processor(**processor_kwargs)
|
||||
inputs = self.postprocess_inputs(inputs)
|
||||
|
||||
all_inputs.append(inputs)
|
||||
|
||||
return all_inputs
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[List[np.ndarray]] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[List[List[int]], List[str]]]:
|
||||
all_inputs = self.get_inputs(prompts,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios)
|
||||
|
||||
outputs: List[Tuple[List[List[int]], List[str]]] = []
|
||||
for inputs in all_inputs:
|
||||
output_ids = self.model.generate(
|
||||
**self.wrap_device(inputs, device=self.model.device.type),
|
||||
use_cache=True,
|
||||
@ -350,12 +379,16 @@ class HfRunner:
|
||||
prompts: List[str],
|
||||
max_tokens: int,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[List[np.ndarray]] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
outputs = self.generate(prompts,
|
||||
do_sample=False,
|
||||
max_new_tokens=max_tokens,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios,
|
||||
**kwargs)
|
||||
|
||||
return [(output_ids[0], output_str[0])
|
||||
@ -388,22 +421,16 @@ class HfRunner:
|
||||
max_tokens: int,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[List[np.ndarray]] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[List[torch.Tensor]]:
|
||||
all_inputs = self.get_inputs(prompts,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios)
|
||||
|
||||
all_logprobs: List[List[torch.Tensor]] = []
|
||||
for i, prompt in enumerate(prompts):
|
||||
processor_kwargs: Dict[str, Any] = {
|
||||
"text": prompt,
|
||||
"return_tensors": "pt",
|
||||
}
|
||||
if images is not None and images[i] is not None:
|
||||
processor_kwargs["images"] = images[i]
|
||||
if videos is not None and videos[i] is not None:
|
||||
processor_kwargs["videos"] = videos[i]
|
||||
|
||||
inputs = self.processor(**processor_kwargs)
|
||||
inputs = self.postprocess_inputs(inputs)
|
||||
|
||||
for inputs in all_inputs:
|
||||
output = self.model.generate(
|
||||
**self.wrap_device(inputs, device=self.model.device.type),
|
||||
use_cache=True,
|
||||
@ -475,28 +502,16 @@ class HfRunner:
|
||||
videos: Optional[List[np.ndarray]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[TokensTextLogprobs]:
|
||||
all_inputs = self.get_inputs(prompts,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios)
|
||||
|
||||
all_logprobs: List[List[Dict[int, float]]] = []
|
||||
all_output_ids: List[List[int]] = []
|
||||
all_output_strs: List[str] = []
|
||||
|
||||
for i, prompt in enumerate(prompts):
|
||||
processor_kwargs: Dict[str, Any] = {
|
||||
"text": prompt,
|
||||
"return_tensors": "pt",
|
||||
}
|
||||
if images is not None and images[i] is not None:
|
||||
processor_kwargs["images"] = images[i]
|
||||
|
||||
if audios is not None:
|
||||
audio, sr = audios[i]
|
||||
processor_kwargs["audio"] = audio
|
||||
processor_kwargs["sampling_rate"] = sr
|
||||
|
||||
if videos is not None:
|
||||
processor_kwargs["videos"] = videos[i]
|
||||
inputs = self.processor(**processor_kwargs)
|
||||
inputs = self.postprocess_inputs(inputs)
|
||||
|
||||
for inputs in all_inputs:
|
||||
output = self.model.generate(
|
||||
**self.wrap_device(inputs, device=self.model.device.type),
|
||||
use_cache=True,
|
||||
@ -632,20 +647,50 @@ class VllmRunner:
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def generate(
|
||||
def get_inputs(
|
||||
self,
|
||||
prompts: List[str],
|
||||
sampling_params: SamplingParams,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
) -> List[Tuple[List[List[int]], List[str]]]:
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
) -> List[TextPrompt]:
|
||||
if images is not None:
|
||||
assert len(prompts) == len(images)
|
||||
|
||||
if videos is not None:
|
||||
assert len(prompts) == len(videos)
|
||||
|
||||
if audios is not None:
|
||||
assert len(prompts) == len(audios)
|
||||
|
||||
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
|
||||
if images is not None:
|
||||
for i, image in enumerate(images):
|
||||
inputs[i]["multi_modal_data"] = {"image": image}
|
||||
|
||||
if videos is not None:
|
||||
for i, video in enumerate(videos):
|
||||
inputs[i]["multi_modal_data"] = {"video": video}
|
||||
|
||||
if audios is not None:
|
||||
for i, audio in enumerate(audios):
|
||||
inputs[i]["multi_modal_data"] = {"audio": audio}
|
||||
|
||||
return inputs
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
sampling_params: SamplingParams,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
) -> List[Tuple[List[List[int]], List[str]]]:
|
||||
inputs = self.get_inputs(prompts,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios)
|
||||
|
||||
req_outputs = self.model.generate(inputs,
|
||||
sampling_params=sampling_params)
|
||||
|
||||
@ -687,24 +732,10 @@ class VllmRunner:
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
) -> Union[List[TokensTextLogprobs],
|
||||
List[TokensTextLogprobsPromptLogprobs]]:
|
||||
if images is not None:
|
||||
assert len(prompts) == len(images)
|
||||
|
||||
if videos is not None:
|
||||
assert len(prompts) == len(videos)
|
||||
|
||||
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
|
||||
if images is not None:
|
||||
for i, image in enumerate(images):
|
||||
inputs[i]["multi_modal_data"] = {"image": image}
|
||||
|
||||
if audios is not None:
|
||||
for i, audio in enumerate(audios):
|
||||
inputs[i]["multi_modal_data"] = {"audio": audio}
|
||||
|
||||
if videos is not None:
|
||||
for i, video in enumerate(videos):
|
||||
inputs[i]["multi_modal_data"] = {"video": video}
|
||||
inputs = self.get_inputs(prompts,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios)
|
||||
|
||||
req_outputs = self.model.generate(inputs,
|
||||
sampling_params=sampling_params)
|
||||
@ -741,9 +772,15 @@ class VllmRunner:
|
||||
prompts: List[str],
|
||||
max_tokens: int,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
||||
outputs = self.generate(prompts, greedy_params, images=images)
|
||||
outputs = self.generate(prompts,
|
||||
greedy_params,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios)
|
||||
return [(output_ids[0], output_str[0])
|
||||
for output_ids, output_str in outputs]
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
|
||||
"""Compare the embedding outputs of HF and vLLM models.
|
||||
|
||||
Run `pytest tests/models/embedding/language/test_embedding.py`.
|
||||
"""
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import check_embeddings_close
|
||||
|
||||
MODELS = [
|
||||
"intfloat/e5-mistral-7b-instruct",
|
||||
@ -12,14 +12,6 @@ MODELS = [
|
||||
]
|
||||
|
||||
|
||||
def compare_embeddings(embeddings1, embeddings2):
|
||||
similarities = [
|
||||
F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0)
|
||||
for e1, e2 in zip(embeddings1, embeddings2)
|
||||
]
|
||||
return similarities
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_models(
|
||||
@ -37,15 +29,17 @@ def test_models(
|
||||
# So we need to strip the input texts to avoid test failing.
|
||||
example_prompts = [str(s).strip() for s in example_prompts]
|
||||
|
||||
with hf_runner(model, dtype=dtype, is_embedding_model=True) as hf_model:
|
||||
with hf_runner(model, dtype=dtype,
|
||||
is_sentence_transformer=True) as hf_model:
|
||||
hf_outputs = hf_model.encode(example_prompts)
|
||||
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
vllm_outputs = vllm_model.encode(example_prompts)
|
||||
|
||||
similarities = compare_embeddings(hf_outputs, vllm_outputs)
|
||||
all_similarities = torch.stack(similarities)
|
||||
tolerance = 1e-2
|
||||
assert torch.all((all_similarities <= 1.0 + tolerance)
|
||||
& (all_similarities >= 1.0 - tolerance)
|
||||
), f"Not all values are within {tolerance} of 1.0"
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=hf_outputs,
|
||||
embeddings_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
tol=1e-2,
|
||||
)
|
||||
|
tests/models/embedding/utils.py (new file, 29 lines)
@@ -0,0 +1,29 @@
from typing import List, Sequence

import torch
import torch.nn.functional as F


def check_embeddings_close(
    *,
    embeddings_0_lst: Sequence[List[float]],
    embeddings_1_lst: Sequence[List[float]],
    name_0: str,
    name_1: str,
    tol: float = 1e-3,
) -> None:
    assert len(embeddings_0_lst) == len(embeddings_1_lst)

    for prompt_idx, (embeddings_0, embeddings_1) in enumerate(
            zip(embeddings_0_lst, embeddings_1_lst)):
        assert len(embeddings_0) == len(embeddings_1)

        sim = F.cosine_similarity(torch.tensor(embeddings_0),
                                  torch.tensor(embeddings_1),
                                  dim=0)

        fail_msg = (f"Test{prompt_idx}:"
                    f"\n{name_0}:\t{embeddings_0!r}"
                    f"\n{name_1}:\t{embeddings_1!r}")

        assert sim >= 1 - tol, fail_msg
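A minimal usage sketch of the helper above, with made-up values; the two vectors differ slightly but their cosine similarity stays within the tolerance:

check_embeddings_close(
    embeddings_0_lst=[[0.0, 0.6, 0.8]],
    embeddings_1_lst=[[0.0, 0.601, 0.799]],
    name_0="hf",
    name_1="vllm",
    tol=1e-2,
)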
tests/models/embedding/vision_language/__init__.py (new, empty file)

tests/models/embedding/vision_language/test_phi3v.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import pytest
import torch.nn.functional as F

from ....conftest import IMAGE_ASSETS
from ..utils import check_embeddings_close

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "<|image_1|> Select the portion of the image that isolates the object of the given label: The label of the object is stop sign",  # noqa: E501
    "cherry_blossom":
    "<|image_1|> Represent the given image with the following question: What is in the image",  # noqa: E501
})

MODELS = ["TIGER-Lab/VLM2Vec-Full"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
    with vllm_runner(model,
                     max_model_len=4096,
                     max_num_seqs=2,
                     dtype=dtype,
                     enforce_eager=True) as vllm_model:
        vllm_outputs = vllm_model.encode(example_prompts)

    with hf_runner(model, dtype=dtype) as hf_model:
        all_inputs = hf_model.get_inputs(example_prompts)

        all_outputs = []
        for inputs in all_inputs:
            # Based on: https://github.com/TIGER-AI-Lab/VLM2Vec/blob/db3b951bccabba220c1f53ab46a734e50dd2fc08/src/model.py
            outputs = hf_model.model(
                **hf_model.wrap_device(inputs,
                                       device=hf_model.model.device.type),
                return_dict=True,
                output_hidden_states=True,
            )
            last_hidden_state = outputs.hidden_states[-1][0]
            reps = last_hidden_state[inputs.attention_mask[0].sum() - 1]
            pooled_output = F.normalize(reps, p=2, dim=-1)

            all_outputs.append(pooled_output.tolist())

        hf_outputs = all_outputs

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )
@ -3,7 +3,7 @@ from typing import List, Optional, Union
|
||||
import torch
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.model_executor.models.gemma2_embedding import Gemma2EmbeddingModel
|
||||
from vllm.model_executor.models.gemma2 import Gemma2EmbeddingModel
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
|
||||
|
@ -237,7 +237,16 @@ class ModelConfig:
|
||||
|
||||
def _verify_embedding_mode(self) -> None:
|
||||
architectures = getattr(self.hf_config, "architectures", [])
|
||||
self.embedding_mode = ModelRegistry.is_embedding_model(architectures)
|
||||
|
||||
# TODO: Allow the same model architecture to be specified as either
|
||||
# generation or embedding model
|
||||
if "Phi3VForCausalLM" in architectures:
|
||||
# Match both remote and local names
|
||||
embedding_mode = "/VLM2Vec" in self.model
|
||||
else:
|
||||
embedding_mode = ModelRegistry.is_embedding_model(architectures)
|
||||
|
||||
self.embedding_mode = embedding_mode
|
||||
|
||||
def _parse_quant_hf_config(self):
|
||||
quant_cfg = getattr(self.hf_config, "quantization_config", None)
|
||||
|
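To make the interim embedding-mode heuristic in ModelConfig._verify_embedding_mode above concrete, here is a hedged, standalone sketch (the helper name is hypothetical, not vLLM code) of how the architecture list and model name decide whether the checkpoint is served as an embedding model:

def is_embedding_checkpoint(architectures: list, model: str) -> bool:
    # Mirrors the logic added in this commit (simplified).
    if "Phi3VForCausalLM" in architectures:
        # Match both remote and local names, e.g. "TIGER-Lab/VLM2Vec-Full".
        return "/VLM2Vec" in model
    # Otherwise fall back to the model-registry lookup (stubbed out here).
    return False

assert is_embedding_checkpoint(["Phi3VForCausalLM"], "TIGER-Lab/VLM2Vec-Full")
assert not is_embedding_checkpoint(["Phi3VForCausalLM"],
                                   "microsoft/Phi-3-vision-128k-instruct")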
@ -31,14 +31,16 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
|
||||
@ -461,3 +463,50 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
loader.load_weights(weights)
|
||||
|
||||
|
||||
class Gemma2EmbeddingModel(nn.Module, SupportsPP):
|
||||
"""
|
||||
A model that uses Gemma2 with additional embedding functionalities.
|
||||
|
||||
This class encapsulates the Gemma2Model and provides an interface for
|
||||
embedding operations and customized pooling functions.
|
||||
|
||||
Attributes:
|
||||
model: An instance of Gemma2Model used for forward operations.
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.model = Gemma2Model(**kwargs)
|
||||
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
return self.model(input_ids, positions, kv_caches, attn_metadata,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
self.model.load_weights(weights)
|
||||
|
@ -1,57 +0,0 @@
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
from .gemma2 import Gemma2Model
|
||||
from .interfaces import SupportsPP
|
||||
|
||||
|
||||
class Gemma2EmbeddingModel(nn.Module, SupportsPP):
|
||||
"""A model that uses Gemma2 with additional embedding functionalities.
|
||||
|
||||
This class encapsulates the Gemma2Model and provides an interface for
|
||||
embedding operations and customized pooling functions.
|
||||
|
||||
Attributes:
|
||||
model: An instance of Gemma2Model used for forward operations.
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.model = Gemma2Model(**kwargs)
|
||||
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
return self.model(input_ids, positions, kv_caches, attn_metadata,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
self.model.load_weights(weights)
|
@ -38,6 +38,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
get_compressed_tensors_cache_scale)
|
||||
@ -47,8 +48,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.utils import is_hip
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
@ -615,3 +617,52 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
name = name.replace(item, mapping[item])
|
||||
|
||||
return name, loaded_weight
|
||||
|
||||
|
||||
class LlamaEmbeddingModel(nn.Module, SupportsPP):
|
||||
"""
|
||||
A model that uses Llama with additional embedding functionalities.
|
||||
|
||||
This class encapsulates the LlamaModel and provides an interface for
|
||||
embedding operations and customized pooling functions.
|
||||
|
||||
Attributes:
|
||||
model: An instance of LlamaModel used for forward operations.
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.model = LlamaModel(**kwargs)
|
||||
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
return self.model(input_ids, positions, kv_caches, attn_metadata,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
self.model.load_weights(weights)
|
||||
|
||||
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
||||
self.model.load_kv_cache_scales(quantization_param_path)
|
||||
|
@ -1,59 +0,0 @@
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .llama import LlamaModel
|
||||
|
||||
|
||||
class LlamaEmbeddingModel(nn.Module, SupportsPP):
|
||||
"""A model that uses Llama with additional embedding functionalities.
|
||||
|
||||
This class encapsulates the LlamaModel and provides an interface for
|
||||
embedding operations and customized pooling functions.
|
||||
|
||||
Attributes:
|
||||
model: An instance of LlamaModel used for forward operations.
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.model = LlamaModel(**kwargs)
|
||||
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
return self.model(input_ids, positions, kv_caches, attn_metadata,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
self.model.load_weights(weights)
|
||||
|
||||
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
||||
self.model.load_kv_cache_scales(quantization_param_path)
|
@ -29,14 +29,18 @@ from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.clip import CLIPVisionModel
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
|
||||
@ -289,10 +293,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
dim=2).reshape(num_images, -1, hid_dim)
|
||||
return image_features_hd_newline
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
loader.load_weights(weights)
|
||||
|
||||
|
||||
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
|
||||
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
|
||||
@ -385,23 +385,28 @@ def dummy_data_for_phi3v(ctx: InputContext,
|
||||
return seq_data, mm_data
|
||||
|
||||
|
||||
# Reserve this function to also handle placeholders for additional images
|
||||
# [ref: PR #5820]
|
||||
@lru_cache
|
||||
def _get_image_placeholder_token_ids(model_config: ModelConfig,
|
||||
idx: int) -> List[int]:
|
||||
def _get_image_placeholder_token_id_candidates(
|
||||
model_config: ModelConfig,
|
||||
idx: int,
|
||||
) -> List[List[int]]:
|
||||
assert idx > 0
|
||||
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||
|
||||
# This is used when the image token is at the start of the string
|
||||
start_candidate = tokenizer.encode(f"<|image_{idx}|>",
|
||||
add_special_tokens=False)
|
||||
|
||||
# This is used when the image token is in the middle of the string
|
||||
# We need to get the token for "<", not "▁<"
|
||||
# https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
|
||||
a_token_id, = tokenizer.encode("a", add_special_tokens=False)
|
||||
a_token_id_, *image_placeholder_token_ids = tokenizer.encode(
|
||||
f"a<|image_{idx}|>", add_special_tokens=False)
|
||||
a_token_id_, *middle_candidate = tokenizer.encode(f"a<|image_{idx}|>",
|
||||
add_special_tokens=False)
|
||||
assert a_token_id == a_token_id_
|
||||
|
||||
return image_placeholder_token_ids
|
||||
return [start_candidate, middle_candidate]
|
||||
|
||||
|
||||
def input_processor_for_phi3v(ctx: InputContext,
|
||||
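A hedged, standalone sketch of the tokenizer behavior that _get_image_placeholder_token_id_candidates above accounts for: the placeholder tokenizes differently at the start of a string ("<") than after other text ("▁<"), so both candidate sequences must be matched.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-vision-128k-instruct")

start_candidate = tokenizer.encode("<|image_1|>", add_special_tokens=False)
_a_token_id, *middle_candidate = tokenizer.encode("a<|image_1|>",
                                                  add_special_tokens=False)

# The leading token ids differ, so matching only one candidate would miss
# placeholders in the other position.
print(start_candidate)
print(middle_candidate)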
@ -461,16 +466,20 @@ def input_processor_for_phi3v(ctx: InputContext,
|
||||
|
||||
prompt_token_ids = llm_inputs["prompt_token_ids"].copy()
|
||||
|
||||
# masked place_holder with image token id
# masked placeholder with image token id
|
||||
for idx in image_idx:
|
||||
image_token_ids = _get_image_placeholder_token_ids(model_config,
|
||||
idx=idx)
|
||||
for i in range(len(prompt_token_ids) - len(image_token_ids) + 1):
|
||||
if prompt_token_ids[i:i + len(image_token_ids)] == image_token_ids:
|
||||
prompt_token_ids[i:i + len(image_token_ids)] = [
|
||||
_IMAGE_TOKEN_ID
|
||||
] * len(image_token_ids)
|
||||
break
|
||||
candidates = _get_image_placeholder_token_id_candidates(model_config,
|
||||
idx=idx)
|
||||
|
||||
for candidate in candidates:
|
||||
for i in range(len(prompt_token_ids) - len(candidate) + 1):
|
||||
if prompt_token_ids[i:i + len(candidate)] == candidate:
|
||||
prompt_token_ids[i:i +
|
||||
len(candidate)] = ([_IMAGE_TOKEN_ID] *
|
||||
len(candidate))
|
||||
break
|
||||
|
||||
# merge consecutive tag ids
|
||||
merged_token_ids: List[int] = []
|
||||
@ -520,12 +529,23 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.multimodal_config = multimodal_config
|
||||
self.image_token_id = _IMAGE_TOKEN_ID
|
||||
|
||||
# TODO: Optionally initializes this for supporting embeddings.
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# TODO: Optionally initializes this for supporting input embeddings.
|
||||
self.vision_embed_tokens = Phi3HDImageEmbedding(config)
|
||||
|
||||
self.language_model = LlamaForCausalLM(config, cache_config,
|
||||
quant_config)
|
||||
|
||||
# The same model class supports both language generation and embedding
|
||||
# because the architecture name is the same
|
||||
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
||||
@ -649,8 +669,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
if image_input is not None:
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
inputs_embeds = self.language_model.model.get_input_embeddings(
|
||||
input_ids)
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, vision_embeddings,
|
||||
self.image_token_id)
|
||||
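For the merge step used above, a hedged, simplified sketch (not vLLM's actual merge_multimodal_embeddings) of the idea: positions holding the image placeholder id receive the vision embeddings, while all other positions keep the text embeddings.

import torch

def merge_multimodal_embeddings_sketch(input_ids: torch.Tensor,
                                       inputs_embeds: torch.Tensor,
                                       vision_embeddings: torch.Tensor,
                                       image_token_id: int) -> torch.Tensor:
    # input_ids: [num_tokens], inputs_embeds: [num_tokens, hidden_size]
    mask = input_ids == image_token_id
    flat_vision = vision_embeddings.to(inputs_embeds.dtype).view(
        -1, inputs_embeds.shape[-1])
    inputs_embeds[mask] = flat_vision
    return inputs_embeds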
@ -682,13 +701,27 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
) -> Optional[SamplerOutput]:
|
||||
return self.language_model.sample(logits, sampling_metadata)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"model.vision_embed_tokens.wte": "embed_tokens",
|
||||
"model.vision_embed_tokens.": "vision_embed_tokens.",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
"model.": "language_model.model.",
|
||||
})
|
||||
|
||||
loader = AutoWeightsLoader(self)
|
||||
loader.load_weights(weights, mapper=hf_to_vllm_mapper)
|
||||
autoloaded_weights = loader.load_weights(weights,
|
||||
mapper=hf_to_vllm_mapper)
|
||||
|
||||
# The HF config doesn't specify whether these are tied,
|
||||
# so we detect it this way
|
||||
if "embed_tokens" not in autoloaded_weights:
|
||||
self.embed_tokens = self.language_model.model.embed_tokens
|
||||
|
@ -86,9 +86,12 @@ _TEXT_GENERATION_MODELS = {
|
||||
}
|
||||
|
||||
_EMBEDDING_MODELS = {
|
||||
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
|
||||
# [Text-only]
|
||||
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
|
||||
"MistralModel": ("llama", "LlamaEmbeddingModel"),
|
||||
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
|
||||
"Gemma2Model": ("gemma2_embedding", "Gemma2EmbeddingModel"),
|
||||
# [Multimodal]
|
||||
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
|
||||
}
|
||||
|
||||
_MULTIMODAL_MODELS = {
|
||||
|
@ -124,7 +124,7 @@ class AutoWeightsLoader:
|
||||
base_prefix: str,
|
||||
param: nn.Parameter,
|
||||
weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
) -> None:
|
||||
) -> Iterable[str]:
|
||||
for weight_name, weight_data in weights:
|
||||
weight_qualname = self._get_qualname(base_prefix, weight_name)
|
||||
|
||||
@ -143,12 +143,14 @@ class AutoWeightsLoader:
|
||||
default_weight_loader)
|
||||
weight_loader(param, weight_data)
|
||||
|
||||
yield weight_qualname
|
||||
|
||||
def _load_module(
|
||||
self,
|
||||
base_prefix: str,
|
||||
module: nn.Module,
|
||||
weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
) -> None:
|
||||
) -> Iterable[str]:
|
||||
if isinstance(module, PPMissingLayer):
|
||||
return
|
||||
|
||||
@ -170,14 +172,16 @@ class AutoWeightsLoader:
|
||||
continue
|
||||
|
||||
if child_prefix in child_modules:
|
||||
self._load_module(prefix, child_modules[child_prefix],
|
||||
child_weights)
|
||||
yield from self._load_module(prefix,
|
||||
child_modules[child_prefix],
|
||||
child_weights)
|
||||
elif child_prefix in child_params:
|
||||
self._load_param(prefix, child_params[child_prefix],
|
||||
child_weights)
|
||||
yield from self._load_param(prefix, child_params[child_prefix],
|
||||
child_weights)
|
||||
else:
|
||||
if not self._can_ignore_unexpected(prefix):
|
||||
msg = f"There is no module or parameter named '{prefix}'"
|
||||
msg = (f"There is no module or parameter named '{prefix}' "
|
||||
f"in {type(self.module).__name__}")
|
||||
raise ValueError(msg)
|
||||
|
||||
def load_weights(
|
||||
@ -185,11 +189,12 @@ class AutoWeightsLoader:
|
||||
weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
*,
|
||||
mapper: Optional[WeightsMapper] = None,
|
||||
) -> None:
|
||||
) -> List[str]:
|
||||
if mapper is not None:
|
||||
weights = mapper.apply(weights)
|
||||
|
||||
self._load_module("", self.module, weights)
|
||||
autoloaded_weights = list(self._load_module("", self.module, weights))
|
||||
return autoloaded_weights
|
||||
|
||||
|
||||
def init_vllm_registered_model(
|
||||
|