[Model] VLM2Vec, the first multimodal embedding model in vLLM (#9303)

Author: Cyrus Leung, 2024-10-16 14:31:00 +08:00 (committed by GitHub)
parent 7e7eae338d
commit 7abba39ee6
16 changed files with 465 additions and 261 deletions

View File

@ -3,7 +3,7 @@
Supported Models
================
vLLM supports a variety of generative Transformer models in `HuggingFace Transformers <https://huggingface.co/models>`_.
vLLM supports a variety of generative Transformer models in `HuggingFace (HF) Transformers <https://huggingface.co/models>`_.
The following is the list of model architectures that are currently supported by vLLM.
Alongside each architecture, we include some popular models that use it.
@ -19,7 +19,7 @@ Text Generation
* - Architecture
- Models
- Example HuggingFace Models
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`AquilaForCausalLM`
@ -280,7 +280,7 @@ Text Embedding
* - Architecture
- Models
- Example HuggingFace Models
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`Gemma2Model`
@ -303,7 +303,7 @@ Reward Modeling
* - Architecture
- Models
- Example HuggingFace Models
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`Qwen2ForRewardModel`
@ -316,7 +316,14 @@ Reward Modeling
As an interim measure, these models are supported via Embeddings API. See `this RFC <https://github.com/vllm-project/vllm/issues/8967>`_ for upcoming changes.
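
For concreteness, the interim Embeddings API path can be exercised offline the same way as a text embedding model. The sketch below is hedged: the model name is a placeholder for one of the reward models listed above, and the exact meaning of the returned values is model-specific (see the RFC for the planned dedicated interface).

from vllm import LLM

# Hedged sketch: querying a reward model through the embeddings pathway.
# "org/reward-model" is a placeholder; substitute a supported reward model.
llm = LLM(model="org/reward-model")
outputs = llm.encode("Question: 1 + 1 = ? Answer: 2")

# In this interim setup, the pooled values come back through the same
# EmbeddingRequestOutput structure that embedding models use.
print(outputs[0].outputs.embedding)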
Multimodal Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^
The following modalities are supported depending on the model:
- **T**\ ext
- **I**\ mage
- **V**\ ideo
- **A**\ udio
.. _supported_vlms:
@ -324,78 +331,78 @@ Text Generation
---------------
.. list-table::
:widths: 25 25 25 25 5 5
:widths: 25 25 15 25 5 5
:header-rows: 1
* - Architecture
- Models
- Modalities
- Example HuggingFace Models
- Inputs
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`Blip2ForConditionalGeneration`
- BLIP-2
- Image\ :sup:`E`
- T + I\ :sup:`E`
- :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc.
-
- ✅︎
* - :code:`ChameleonForConditionalGeneration`
- Chameleon
- Image
- T + I
- :code:`facebook/chameleon-7b` etc.
-
- ✅︎
* - :code:`FuyuForCausalLM`
- Fuyu
- Image
- T + I
- :code:`adept/fuyu-8b` etc.
-
- ✅︎
* - :code:`ChatGLMModel`
- GLM-4V
- Image
- T + I
- :code:`THUDM/glm-4v-9b` etc.
-
- ✅︎
* - :code:`InternVLChatModel`
- InternVL2
- Image\ :sup:`E+`
- T + I\ :sup:`E+`
- :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
-
- ✅︎
* - :code:`LlavaForConditionalGeneration`
- LLaVA-1.5
- Image\ :sup:`E+`
- T + I\ :sup:`E+`
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
-
- ✅︎
* - :code:`LlavaNextForConditionalGeneration`
- LLaVA-NeXT
- Image\ :sup:`E+`
- T + I\ :sup:`E+`
- :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
-
- ✅︎
* - :code:`LlavaNextVideoForConditionalGeneration`
- LLaVA-NeXT-Video
- Video
- T + V
- :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc.
-
- ✅︎
* - :code:`LlavaOnevisionForConditionalGeneration`
- LLaVA-Onevision
- Image\ :sup:`+` / Video
- T + I\ :sup:`+` + V
- :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc.
-
- ✅︎
* - :code:`MiniCPMV`
- MiniCPM-V
- Image\ :sup:`E+`
- T + I\ :sup:`E+`
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
- ✅︎
- ✅︎
* - :code:`MllamaForConditionalGeneration`
- Llama 3.2
- Image
- T + I
- :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc.
-
-
@ -407,43 +414,43 @@ Text Generation
- ✅︎
* - :code:`NVLM_D_Model`
- NVLM-D 1.0
- Image\ :sup:`E+`
- T + I\ :sup:`E+`
- :code:`nvidia/NVLM-D-72B`, etc.
-
- ✅︎
* - :code:`PaliGemmaForConditionalGeneration`
- PaliGemma
- Image\ :sup:`E`
- T + I\ :sup:`E`
- :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc.
-
- ✅︎
* - :code:`Phi3VForCausalLM`
- Phi-3-Vision, Phi-3.5-Vision
- Image\ :sup:`E+`
- T + I\ :sup:`E+`
- :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc.
-
- ✅︎
* - :code:`PixtralForConditionalGeneration`
- Pixtral
- Image\ :sup:`+`
- T + I\ :sup:`+`
- :code:`mistralai/Pixtral-12B-2409`
-
- ✅︎
* - :code:`QWenLMHeadModel`
- Qwen-VL
- Image\ :sup:`E+`
- T + I\ :sup:`E+`
- :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
-
- ✅︎
* - :code:`Qwen2VLForConditionalGeneration`
- Qwen2-VL
- Image\ :sup:`E+` / Video\ :sup:`+`
- T + I\ :sup:`E+` + V\ :sup:`+`
- :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
-
- ✅︎
* - :code:`UltravoxModel`
- Ultravox
- Audio\ :sup:`E+`
- T + A\ :sup:`E+`
- :code:`fixie-ai/ultravox-v0_3`
-
- ✅︎
@ -455,6 +462,26 @@ Text Generation
For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
Multimodal Embedding
--------------------
.. list-table::
:widths: 25 25 15 25 5 5
:header-rows: 1
* - Architecture
- Models
- Inputs
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`Phi3VForCausalLM`
- Phi-3-Vision-based
- T + I
- :code:`TIGER-Lab/VLM2Vec-Full`
- 🚧
- ✅︎
----
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.

View File

@ -0,0 +1,21 @@
from vllm import LLM
from vllm.assets.image import ImageAsset
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
prompt = "<|image_1|> Represent the given image with the following question: What is in the image" # noqa: E501
# Create an LLM.
llm = LLM(
model="TIGER-Lab/VLM2Vec-Full",
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=2,
mm_processor_kwargs={"num_crops": 16},
)
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = llm.encode({"prompt": prompt, "multi_modal_data": {"image": image}})
# Print the outputs.
for output in outputs:
print(output.outputs.embedding) # list of 3072 floats
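
A natural follow-up to the script above is to compare the image embedding against the embedding of a second prompt. A minimal sketch, assuming the `llm` and `outputs` objects from above; the text-only query below is illustrative and not part of this PR:

import torch
import torch.nn.functional as F

# Embed a second, text-only query with the same model (illustrative prompt).
text_outputs = llm.encode("Represent the given text: cherry blossoms in spring")

image_emb = torch.tensor(outputs[0].outputs.embedding)
text_emb = torch.tensor(text_outputs[0].outputs.embedding)

# The pooler normalizes embeddings, so the dot product already equals the
# cosine similarity; F.cosine_similarity is used here for clarity.
print(F.cosine_similarity(image_emb, text_emb, dim=0).item())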

View File

@ -262,7 +262,7 @@ class HfRunner:
dtype: str = "half",
*,
model_kwargs: Optional[Dict[str, Any]] = None,
is_embedding_model: bool = False,
is_sentence_transformer: bool = False,
auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
postprocess_inputs: Callable[[BatchEncoding],
BatchEncoding] = identity,
@ -271,7 +271,7 @@ class HfRunner:
self.model_name = model_name
if is_embedding_model:
if is_sentence_transformer:
# Lazy init required for AMD CI
from sentence_transformers import SentenceTransformer
self.model = self.wrap_device(
@ -307,17 +307,23 @@ class HfRunner:
self.postprocess_inputs = postprocess_inputs
def generate(
def get_inputs(
self,
prompts: List[str],
images: Optional[PromptImageInput] = None,
videos: Optional[List[np.ndarray]] = None,
**kwargs: Any,
) -> List[Tuple[List[List[int]], List[str]]]:
if images:
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[BatchEncoding]:
if images is not None:
assert len(prompts) == len(images)
outputs: List[Tuple[List[List[int]], List[str]]] = []
if videos is not None:
assert len(prompts) == len(videos)
if audios is not None:
assert len(prompts) == len(audios)
all_inputs: List[BatchEncoding] = []
for i, prompt in enumerate(prompts):
processor_kwargs: Dict[str, Any] = {
"text": prompt,
@ -327,10 +333,33 @@ class HfRunner:
processor_kwargs["images"] = images[i]
if videos is not None and videos[i] is not None:
processor_kwargs["videos"] = videos[i]
if audios is not None and audios[i] is not None:
audio, sr = audios[i]
processor_kwargs["audio"] = audio
processor_kwargs["sampling_rate"] = sr
inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)
all_inputs.append(inputs)
return all_inputs
def generate(
self,
prompts: List[str],
images: Optional[PromptImageInput] = None,
videos: Optional[List[np.ndarray]] = None,
audios: Optional[PromptAudioInput] = None,
**kwargs: Any,
) -> List[Tuple[List[List[int]], List[str]]]:
all_inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
outputs: List[Tuple[List[List[int]], List[str]]] = []
for inputs in all_inputs:
output_ids = self.model.generate(
**self.wrap_device(inputs, device=self.model.device.type),
use_cache=True,
@ -350,12 +379,16 @@ class HfRunner:
prompts: List[str],
max_tokens: int,
images: Optional[PromptImageInput] = None,
videos: Optional[List[np.ndarray]] = None,
audios: Optional[PromptAudioInput] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str]]:
outputs = self.generate(prompts,
do_sample=False,
max_new_tokens=max_tokens,
images=images,
videos=videos,
audios=audios,
**kwargs)
return [(output_ids[0], output_str[0])
@ -388,22 +421,16 @@ class HfRunner:
max_tokens: int,
images: Optional[PromptImageInput] = None,
videos: Optional[List[np.ndarray]] = None,
audios: Optional[PromptAudioInput] = None,
**kwargs: Any,
) -> List[List[torch.Tensor]]:
all_inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
all_logprobs: List[List[torch.Tensor]] = []
for i, prompt in enumerate(prompts):
processor_kwargs: Dict[str, Any] = {
"text": prompt,
"return_tensors": "pt",
}
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
if videos is not None and videos[i] is not None:
processor_kwargs["videos"] = videos[i]
inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)
for inputs in all_inputs:
output = self.model.generate(
**self.wrap_device(inputs, device=self.model.device.type),
use_cache=True,
@ -475,28 +502,16 @@ class HfRunner:
videos: Optional[List[np.ndarray]] = None,
**kwargs: Any,
) -> List[TokensTextLogprobs]:
all_inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
all_logprobs: List[List[Dict[int, float]]] = []
all_output_ids: List[List[int]] = []
all_output_strs: List[str] = []
for i, prompt in enumerate(prompts):
processor_kwargs: Dict[str, Any] = {
"text": prompt,
"return_tensors": "pt",
}
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
if audios is not None:
audio, sr = audios[i]
processor_kwargs["audio"] = audio
processor_kwargs["sampling_rate"] = sr
if videos is not None:
processor_kwargs["videos"] = videos[i]
inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)
for inputs in all_inputs:
output = self.model.generate(
**self.wrap_device(inputs, device=self.model.device.type),
use_cache=True,
@ -632,20 +647,50 @@ class VllmRunner:
**kwargs,
)
def generate(
def get_inputs(
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[PromptImageInput] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[TextPrompt]:
if images is not None:
assert len(prompts) == len(images)
if videos is not None:
assert len(prompts) == len(videos)
if audios is not None:
assert len(prompts) == len(audios)
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
if images is not None:
for i, image in enumerate(images):
inputs[i]["multi_modal_data"] = {"image": image}
if videos is not None:
for i, video in enumerate(videos):
inputs[i]["multi_modal_data"] = {"video": video}
if audios is not None:
for i, audio in enumerate(audios):
inputs[i]["multi_modal_data"] = {"audio": audio}
return inputs
def generate(
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
@ -687,24 +732,10 @@ class VllmRunner:
videos: Optional[PromptVideoInput] = None,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
if images is not None:
assert len(prompts) == len(images)
if videos is not None:
assert len(prompts) == len(videos)
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
if images is not None:
for i, image in enumerate(images):
inputs[i]["multi_modal_data"] = {"image": image}
if audios is not None:
for i, audio in enumerate(audios):
inputs[i]["multi_modal_data"] = {"audio": audio}
if videos is not None:
for i, video in enumerate(videos):
inputs[i]["multi_modal_data"] = {"video": video}
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
@ -741,9 +772,15 @@ class VllmRunner:
prompts: List[str],
max_tokens: int,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, greedy_params, images=images)
outputs = self.generate(prompts,
greedy_params,
images=images,
videos=videos,
audios=audios)
return [(output_ids[0], output_str[0])
for output_ids, output_str in outputs]

View File

@ -1,10 +1,10 @@
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
"""Compare the embedding outputs of HF and vLLM models.
Run `pytest tests/models/embedding/language/test_embedding.py`.
"""
import pytest
import torch
import torch.nn.functional as F
from ..utils import check_embeddings_close
MODELS = [
"intfloat/e5-mistral-7b-instruct",
@ -12,14 +12,6 @@ MODELS = [
]
def compare_embeddings(embeddings1, embeddings2):
similarities = [
F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0)
for e1, e2 in zip(embeddings1, embeddings2)
]
return similarities
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
@ -37,15 +29,17 @@ def test_models(
# So we need to strip the input texts to avoid test failing.
example_prompts = [str(s).strip() for s in example_prompts]
with hf_runner(model, dtype=dtype, is_embedding_model=True) as hf_model:
with hf_runner(model, dtype=dtype,
is_sentence_transformer=True) as hf_model:
hf_outputs = hf_model.encode(example_prompts)
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
similarities = compare_embeddings(hf_outputs, vllm_outputs)
all_similarities = torch.stack(similarities)
tolerance = 1e-2
assert torch.all((all_similarities <= 1.0 + tolerance)
& (all_similarities >= 1.0 - tolerance)
), f"Not all values are within {tolerance} of 1.0"
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)

View File

@ -0,0 +1,29 @@
from typing import List, Sequence
import torch
import torch.nn.functional as F
def check_embeddings_close(
*,
embeddings_0_lst: Sequence[List[float]],
embeddings_1_lst: Sequence[List[float]],
name_0: str,
name_1: str,
tol: float = 1e-3,
) -> None:
assert len(embeddings_0_lst) == len(embeddings_1_lst)
for prompt_idx, (embeddings_0, embeddings_1) in enumerate(
zip(embeddings_0_lst, embeddings_1_lst)):
assert len(embeddings_0) == len(embeddings_1)
sim = F.cosine_similarity(torch.tensor(embeddings_0),
torch.tensor(embeddings_1),
dim=0)
fail_msg = (f"Test{prompt_idx}:"
f"\n{name_0}:\t{embeddings_0!r}"
f"\n{name_1}:\t{embeddings_1!r}")
assert sim >= 1 - tol, fail_msg

View File

@ -0,0 +1,62 @@
import pytest
import torch.nn.functional as F
from ....conftest import IMAGE_ASSETS
from ..utils import check_embeddings_close
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<|image_1|> Select the portion of the image that isolates the object of the given label: The label of the object is stop sign", # noqa: E501
"cherry_blossom":
"<|image_1|> Represent the given image with the following question: What is in the image", # noqa: E501
})
MODELS = ["TIGER-Lab/VLM2Vec-Full"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
) -> None:
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
with vllm_runner(model,
max_model_len=4096,
max_num_seqs=2,
dtype=dtype,
enforce_eager=True) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
with hf_runner(model, dtype=dtype) as hf_model:
all_inputs = hf_model.get_inputs(example_prompts)
all_outputs = []
for inputs in all_inputs:
# Based on: https://github.com/TIGER-AI-Lab/VLM2Vec/blob/db3b951bccabba220c1f53ab46a734e50dd2fc08/src/model.py
outputs = hf_model.model(
**hf_model.wrap_device(inputs,
device=hf_model.model.device.type),
return_dict=True,
output_hidden_states=True,
)
last_hidden_state = outputs.hidden_states[-1][0]
reps = last_hidden_state[inputs.attention_mask[0].sum() - 1]
pooled_output = F.normalize(reps, p=2, dim=-1)
all_outputs.append(pooled_output.tolist())
hf_outputs = all_outputs
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)

View File

@ -3,7 +3,7 @@ from typing import List, Optional, Union
import torch
from vllm.attention import AttentionMetadata
from vllm.model_executor.models.gemma2_embedding import Gemma2EmbeddingModel
from vllm.model_executor.models.gemma2 import Gemma2EmbeddingModel
from vllm.sequence import IntermediateTensors

View File

@ -237,7 +237,16 @@ class ModelConfig:
def _verify_embedding_mode(self) -> None:
architectures = getattr(self.hf_config, "architectures", [])
self.embedding_mode = ModelRegistry.is_embedding_model(architectures)
# TODO: Allow the same model architecture to be specified as either
# generation or embedding model
if "Phi3VForCausalLM" in architectures:
# Match both remote and local names
embedding_mode = "/VLM2Vec" in self.model
else:
embedding_mode = ModelRegistry.is_embedding_model(architectures)
self.embedding_mode = embedding_mode
def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None)
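
The substring check added above is meant to match both Hugging Face Hub IDs and local checkpoint directories. A small illustration of the intended behavior (the local path is an example, not taken from this PR):

# Illustration of the "/VLM2Vec" match in _verify_embedding_mode: the
# Phi3VForCausalLM architecture is treated as an embedding model only when
# the model path looks like a VLM2Vec checkpoint, remote or local.
for model_path in (
        "TIGER-Lab/VLM2Vec-Full",             # Hub ID -> embedding mode
        "/data/checkpoints/VLM2Vec-Full",     # local path (example) -> embedding mode
        "microsoft/Phi-3.5-vision-instruct",  # plain Phi-3-Vision -> generation
):
    print(model_path, "/VLM2Vec" in model_path)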

View File

@ -31,14 +31,16 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
@ -461,3 +463,50 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if self.config.tie_word_embeddings else None),
)
loader.load_weights(weights)
class Gemma2EmbeddingModel(nn.Module, SupportsPP):
"""
A model that uses Gemma2 with additional embedding functionalities.
This class encapsulates the Gemma2Model and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of Gemma2Model used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def __init__(
self,
**kwargs,
) -> None:
super().__init__()
self.model = Gemma2Model(**kwargs)
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
return self.model(input_ids, positions, kv_caches, attn_metadata,
intermediate_tensors, inputs_embeds)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self.model.load_weights(weights)
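
For readers unfamiliar with the pooler configuration used here, PoolingType.LAST with normalize=True amounts to taking the hidden state at the final token position and L2-normalizing it. A standalone sketch of that computation (shapes are illustrative, not tied to Gemma2):

import torch
import torch.nn.functional as F

# Last-token pooling with L2 normalization, mirroring
# Pooler(pooling_type=PoolingType.LAST, normalize=True).
hidden_states = torch.randn(7, 2048)   # (seq_len, hidden_size), illustrative
last_hidden = hidden_states[-1]        # hidden state of the final token
embedding = F.normalize(last_hidden, p=2, dim=-1)
print(embedding.shape, embedding.norm())  # torch.Size([2048]), ~1.0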

View File

@ -1,57 +0,0 @@
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from .gemma2 import Gemma2Model
from .interfaces import SupportsPP
class Gemma2EmbeddingModel(nn.Module, SupportsPP):
"""A model that uses Gemma2 with additional embedding functionalities.
This class encapsulates the Gemma2Model and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of Gemma2Model used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def __init__(
self,
**kwargs,
) -> None:
super().__init__()
self.model = Gemma2Model(**kwargs)
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
return self.model(input_ids, positions, kv_caches, attn_metadata,
intermediate_tensors, inputs_embeds)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self.model.load_weights(weights)

View File

@ -38,6 +38,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
@ -47,8 +48,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP
@ -615,3 +617,52 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
name = name.replace(item, mapping[item])
return name, loaded_weight
class LlamaEmbeddingModel(nn.Module, SupportsPP):
"""
A model that uses Llama with additional embedding functionalities.
This class encapsulates the LlamaModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of LlamaModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def __init__(
self,
**kwargs,
) -> None:
super().__init__()
self.model = LlamaModel(**kwargs)
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
return self.model(input_ids, positions, kv_caches, attn_metadata,
intermediate_tensors, inputs_embeds)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self.model.load_weights(weights)
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
self.model.load_kv_cache_scales(quantization_param_path)

View File

@ -1,59 +0,0 @@
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsPP
from .llama import LlamaModel
class LlamaEmbeddingModel(nn.Module, SupportsPP):
"""A model that uses Llama with additional embedding functionalities.
This class encapsulates the LlamaModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of LlamaModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def __init__(
self,
**kwargs,
) -> None:
super().__init__()
self.model = LlamaModel(**kwargs)
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
return self.model(input_ids, positions, kv_caches, attn_metadata,
intermediate_tensors, inputs_embeds)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self.model.load_weights(weights)
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
self.model.load_kv_cache_scales(quantization_param_path)

View File

@ -29,14 +29,18 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
from vllm.sequence import IntermediateTensors
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
@ -289,10 +293,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
dim=2).reshape(num_images, -1, hid_dim)
return image_features_hd_newline
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loader.load_weights(weights)
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
@ -385,23 +385,28 @@ def dummy_data_for_phi3v(ctx: InputContext,
return seq_data, mm_data
# Reserve this function to also handle placeholders for additional images
# [ref: PR #5820]
@lru_cache
def _get_image_placeholder_token_ids(model_config: ModelConfig,
idx: int) -> List[int]:
def _get_image_placeholder_token_id_candidates(
model_config: ModelConfig,
idx: int,
) -> List[List[int]]:
assert idx > 0
tokenizer = cached_get_tokenizer(model_config.tokenizer)
# This is used when the image token is at the start of the string
start_candidate = tokenizer.encode(f"<|image_{idx}|>",
add_special_tokens=False)
# This is used when the image token is in the middle of the string
# We need to get the token for "<", not "▁<"
# https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
a_token_id, = tokenizer.encode("a", add_special_tokens=False)
a_token_id_, *image_placeholder_token_ids = tokenizer.encode(
f"a<|image_{idx}|>", add_special_tokens=False)
a_token_id_, *middle_candidate = tokenizer.encode(f"a<|image_{idx}|>",
add_special_tokens=False)
assert a_token_id == a_token_id_
return image_placeholder_token_ids
return [start_candidate, middle_candidate]
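
The reason two candidate sequences are needed can be reproduced directly with the HF tokenizer: with a SentencePiece-style vocabulary, the placeholder tokenizes differently at the start of a string ("▁<") than after other text ("<"). A hedged, standalone sketch mirroring the logic above:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True)

# Placeholder at the start of the string.
start_candidate = tokenizer.encode("<|image_1|>", add_special_tokens=False)

# Placeholder in the middle of the string; drop the leading "a" token.
_, *middle_candidate = tokenizer.encode("a<|image_1|>",
                                        add_special_tokens=False)

# The first token id typically differs between the two candidates.
print(start_candidate)
print(middle_candidate)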
def input_processor_for_phi3v(ctx: InputContext,
@ -461,16 +466,20 @@ def input_processor_for_phi3v(ctx: InputContext,
prompt_token_ids = llm_inputs["prompt_token_ids"].copy()
# masked place_holder with image token id
print("prompt_token_ids (old)", prompt_token_ids)
# masked placeholder with image token id
for idx in image_idx:
image_token_ids = _get_image_placeholder_token_ids(model_config,
idx=idx)
for i in range(len(prompt_token_ids) - len(image_token_ids) + 1):
if prompt_token_ids[i:i + len(image_token_ids)] == image_token_ids:
prompt_token_ids[i:i + len(image_token_ids)] = [
_IMAGE_TOKEN_ID
] * len(image_token_ids)
break
candidates = _get_image_placeholder_token_id_candidates(model_config,
idx=idx)
for candidate in candidates:
for i in range(len(prompt_token_ids) - len(candidate) + 1):
if prompt_token_ids[i:i + len(candidate)] == candidate:
prompt_token_ids[i:i +
len(candidate)] = ([_IMAGE_TOKEN_ID] *
len(candidate))
break
# merge consecutive tag ids
merged_token_ids: List[int] = []
@ -520,12 +529,23 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.multimodal_config = multimodal_config
self.image_token_id = _IMAGE_TOKEN_ID
# TODO: Optionally initializes this for supporting embeddings.
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
)
# TODO: Optionally initializes this for supporting input embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(config)
self.language_model = LlamaForCausalLM(config, cache_config,
quant_config)
# The same model class supports both language generation and embedding
# because the architecture name is the same
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@ -649,8 +669,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
@ -682,13 +701,27 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.vision_embed_tokens.wte": "embed_tokens",
"model.vision_embed_tokens.": "vision_embed_tokens.",
"lm_head.": "language_model.lm_head.",
"model.": "language_model.model.",
})
loader = AutoWeightsLoader(self)
loader.load_weights(weights, mapper=hf_to_vllm_mapper)
autoloaded_weights = loader.load_weights(weights,
mapper=hf_to_vllm_mapper)
# The HF config doesn't specify whether these are tied,
# so we detect it this way
if "embed_tokens" not in autoloaded_weights:
self.embed_tokens = self.language_model.model.embed_tokens

View File

@ -86,9 +86,12 @@ _TEXT_GENERATION_MODELS = {
}
_EMBEDDING_MODELS = {
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
# [Text-only]
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
"MistralModel": ("llama", "LlamaEmbeddingModel"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Gemma2Model": ("gemma2_embedding", "Gemma2EmbeddingModel"),
# [Multimodal]
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
}
_MULTIMODAL_MODELS = {

View File

@ -124,7 +124,7 @@ class AutoWeightsLoader:
base_prefix: str,
param: nn.Parameter,
weights: Iterable[Tuple[str, torch.Tensor]],
) -> None:
) -> Iterable[str]:
for weight_name, weight_data in weights:
weight_qualname = self._get_qualname(base_prefix, weight_name)
@ -143,12 +143,14 @@ class AutoWeightsLoader:
default_weight_loader)
weight_loader(param, weight_data)
yield weight_qualname
def _load_module(
self,
base_prefix: str,
module: nn.Module,
weights: Iterable[Tuple[str, torch.Tensor]],
) -> None:
) -> Iterable[str]:
if isinstance(module, PPMissingLayer):
return
@ -170,14 +172,16 @@ class AutoWeightsLoader:
continue
if child_prefix in child_modules:
self._load_module(prefix, child_modules[child_prefix],
child_weights)
yield from self._load_module(prefix,
child_modules[child_prefix],
child_weights)
elif child_prefix in child_params:
self._load_param(prefix, child_params[child_prefix],
child_weights)
yield from self._load_param(prefix, child_params[child_prefix],
child_weights)
else:
if not self._can_ignore_unexpected(prefix):
msg = f"There is no module or parameter named '{prefix}'"
msg = (f"There is no module or parameter named '{prefix}' "
f"in {type(self.module).__name__}")
raise ValueError(msg)
def load_weights(
@ -185,11 +189,12 @@ class AutoWeightsLoader:
weights: Iterable[Tuple[str, torch.Tensor]],
*,
mapper: Optional[WeightsMapper] = None,
) -> None:
) -> List[str]:
if mapper is not None:
weights = mapper.apply(weights)
self._load_module("", self.module, weights)
autoloaded_weights = list(self._load_module("", self.module, weights))
return autoloaded_weights
def init_vllm_registered_model(