[Model] Adding Support for Qwen2VL as an Embedding Model. Using MrLight/dse-qwen2-2b-mrl-v1 (#9944)

Signed-off-by: FurtherAI <austin.veselka@lighton.ai>
Co-authored-by: FurtherAI <austin.veselka@lighton.ai>

parent 3945c82346
commit 1b886aa104
@@ -584,6 +584,12 @@ Multimodal Embedding
    - :code:`TIGER-Lab/VLM2Vec-Full`
    - 🚧
    - ✅︎
  * - :code:`Qwen2VLForConditionalGeneration`
    - Qwen2-VL-based
    - T + I
    - :code:`MrLight/dse-qwen2-2b-mrl-v1`
    -
    - ✅︎

.. important::
   Some model architectures support both generation and embedding tasks.
@@ -310,4 +310,21 @@ Since the request schema is not defined by OpenAI client, we post a request to t
    response_json = response.json()
    print("Embedding output:", response_json["data"][0]["embedding"])

Here is an example for serving the ``MrLight/dse-qwen2-2b-mrl-v1`` model.

.. code-block:: bash

    vllm serve MrLight/dse-qwen2-2b-mrl-v1 --task embedding \
        --trust-remote-code --max-model-len 8192 --chat-template examples/template_dse_qwen2_vl.jinja

.. important::

    Like with VLM2Vec, we have to explicitly pass ``--task embedding``. Additionally, ``MrLight/dse-qwen2-2b-mrl-v1``
    requires an EOS token for embeddings, which is handled by the jinja template.

.. important::

    ``MrLight/dse-qwen2-2b-mrl-v1`` also requires a placeholder image of the minimum image size for text query
    embeddings. See the full code example below for details.

A full code example can be found in `examples/openai_chat_embedding_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_chat_embedding_client_for_multimodal.py>`_.
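For quick reference, here is a minimal sketch of embedding a text query against the server started with the command above. It is illustrative only: it assumes the server is reachable at ``http://localhost:8000``, and it mirrors the committed example client whose diff follows, including the 56x56 placeholder image required for text queries.

# Minimal sketch, assuming `vllm serve MrLight/dse-qwen2-2b-mrl-v1 ...` from above
# is running locally. The placeholder image satisfies the model's requirement that
# text queries still include a (minimum-size) image.
import base64
import io

import requests
from PIL import Image

buffer = io.BytesIO()
Image.new("RGB", (56, 56)).save(buffer, "png")  # minimum-size placeholder image
placeholder = base64.b64encode(buffer.getvalue()).decode("utf-8")

response = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "MrLight/dse-qwen2-2b-mrl-v1",
        "messages": [{
            "role": "user",
            "content": [
                {"type": "image_url",
                 "image_url": {"url": f"data:image/png;base64,{placeholder}"}},
                {"type": "text", "text": "Query: What is the weather like today?"},
            ],
        }],
        "encoding_format": "float",
    },
)
response.raise_for_status()
print("Embedding output:", response.json()["data"][0]["embedding"])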
@@ -1,8 +1,15 @@
import argparse
import base64
import io

import requests
from PIL import Image

image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"


def vlm2vec():
    response = requests.post(
        "http://localhost:8000/v1/embeddings",
        json={
            "model":

@@ -26,8 +33,88 @@ response = requests.post(
            "encoding_format":
            "float",
        },
    )
    response.raise_for_status()
    response_json = response.json()

    print("Embedding output:", response_json["data"][0]["embedding"])


def dse_qwen2_vl(inp: dict):
    # Embedding an Image
    if inp["dtype"] == "image":
        messages = [{
            "role":
            "user",
            "content": [{
                "type": "image_url",
                "image_url": {
                    "url": inp["image_url"],
                }
            }, {
                "type": "text",
                "text": "What is shown in this image?"
            }]
        }]
    # Embedding a Text Query
    else:
        # MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image
        # of the minimum input size
        buffer = io.BytesIO()
        image_placeholder = Image.new("RGB", (56, 56))
        image_placeholder.save(buffer, "png")
        buffer.seek(0)
        image_placeholder = base64.b64encode(buffer.read()).decode('utf-8')
        messages = [{
            "role":
            "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{image_placeholder}",
                    }
                },
                {
                    "type": "text",
                    "text": f"Query: {inp['content']}"
                },
            ]
        }]

    response = requests.post(
        "http://localhost:8000/v1/embeddings",
        json={
            "model": "MrLight/dse-qwen2-2b-mrl-v1",
            "messages": messages,
            "encoding_format": "float",
        },
    )
    response.raise_for_status()
    response_json = response.json()

    print("Embedding output:", response_json["data"][0]["embedding"])


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Script to call a specified VLM through the API. Make "
        "sure to serve the model with --task embedding before running this.")
    parser.add_argument("model",
                        type=str,
                        choices=["vlm2vec", "dse_qwen2_vl"],
                        help="Which model to call.")
    args = parser.parse_args()

    if args.model == "vlm2vec":
        vlm2vec()
    elif args.model == "dse_qwen2_vl":
        dse_qwen2_vl({
            "dtype": "image",
            "image_url": image_url,
        })
        dse_qwen2_vl({
            "dtype": "text",
            "content": "What is the weather like today?",
        })
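With the server from the documentation section running, the updated client can be invoked as, for example, ``python examples/openai_chat_embedding_client_for_multimodal.py dse_qwen2_vl``, which embeds the sample image and then a text query.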
examples/template_dse_qwen2_vl.jinja (new file)
@@ -0,0 +1,7 @@
{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{% raw %}<|im_start|>system
You are a helpful assistant.<|im_end|>
{% endraw %}{% endif %}<|im_start|>{{ message['role'] }}{% raw %}
{% endraw %}{% if message['content'] is string %}{{ message['content'] }}<|im_end|>{% raw %}
{% endraw %}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>{% raw %}
{% endraw %}{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant{% raw %}
{% endraw %}{% endif %}<|endoftext|>
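As a hypothetical illustration (not part of the commit), the template can be rendered offline with a Hugging Face tokenizer to confirm that image content becomes ``<|vision_start|><|image_pad|><|vision_end|>`` and that the prompt ends with the ``<|endoftext|>`` EOS token the model expects. Loading the tokenizer directly from the ``MrLight/dse-qwen2-2b-mrl-v1`` repo and the exact message shape are assumptions here.

# Hypothetical sketch: render examples/template_dse_qwen2_vl.jinja outside of vLLM.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("MrLight/dse-qwen2-2b-mrl-v1")
with open("examples/template_dse_qwen2_vl.jinja") as f:
    chat_template = f.read()

messages = [{
    "role": "user",
    "content": [
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
        {"type": "text", "text": "Query: cherry blossom"},
    ],
}]
prompt = tokenizer.apply_chat_template(messages,
                                       chat_template=chat_template,
                                       tokenize=False,
                                       add_generation_prompt=True)
print(prompt)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# <|vision_start|><|image_pad|><|vision_end|>Query: cherry blossom<|im_end|>
# <|im_start|>assistant
# <|endoftext|>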
@@ -243,6 +243,9 @@ _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
class HfRunner:

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
        if x is None or isinstance(x, (bool, )):
            # booleans (e.g. use_cache=False) and None have no device to move to
            return x

        if device is None:
            device = "cpu" if current_platform.is_cpu() else "cuda"
209
tests/models/embedding/vision_language/test_dse_qwen2_vl.py
Normal file
209
tests/models/embedding/vision_language/test_dse_qwen2_vl.py
Normal file
@ -0,0 +1,209 @@
|
|||||||
|
from functools import partial
|
||||||
|
from typing import Callable, Dict, List, Type
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import BatchEncoding, Qwen2VLForConditionalGeneration
|
||||||
|
|
||||||
|
from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
|
||||||
|
from ....utils import large_gpu_test
|
||||||
|
from ..utils import check_embeddings_close
|
||||||
|
|
||||||
|
HF_TEXT_PROMPTS = [
|
||||||
|
# T -> X
|
||||||
|
(
|
||||||
|
"Query: Find me an everyday image that matches the given caption: The label of the object is stop sign", # noqa: E501,
|
||||||
|
Image.new("RGB", (56, 56))),
|
||||||
|
# T -> X
|
||||||
|
("Query: Retrieve an image of this caption: cherry blossom",
|
||||||
|
Image.new("RGB", (56, 56))),
|
||||||
|
]
|
||||||
|
|
||||||
|
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||||
|
"stop_sign":
|
||||||
|
"What is shown in this image?",
|
||||||
|
"cherry_blossom":
|
||||||
|
"What is shown in this image?"
|
||||||
|
})
|
||||||
|
|
||||||
|
MODELS = ["MrLight/dse-qwen2-2b-mrl-v1"]
|
||||||
|
|
||||||
|
|
||||||
|
def get_messages(image: Image.Image, text: str, embed_text: bool):
|
||||||
|
# assert False, 'remember to use outer [] as required'
|
||||||
|
if embed_text:
|
||||||
|
messages = [{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"image": Image.new("RGB", (56, 56)),
|
||||||
|
"resized_height": 1,
|
||||||
|
"resized_width": 1
|
||||||
|
}, # need a dummy image here for an easier process.
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": text
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}]
|
||||||
|
else:
|
||||||
|
messages = [{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [{
|
||||||
|
"type": "image",
|
||||||
|
"image": image
|
||||||
|
}, {
|
||||||
|
"type": "text",
|
||||||
|
"text": text
|
||||||
|
}]
|
||||||
|
}]
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
def apply_chat_template_and_add_eos(
|
||||||
|
messages: List[Dict],
|
||||||
|
apply_chat_template_fn: Callable,
|
||||||
|
):
|
||||||
|
prompt = apply_chat_template_fn(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True) + "<|endoftext|>"
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def postprocess_inputs(hf_model: HfRunner, inputs: BatchEncoding, **kwargs):
|
||||||
|
return hf_model.model.prepare_inputs_for_generation(**inputs, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_test(
|
||||||
|
hf_runner: Type[HfRunner],
|
||||||
|
vllm_runner: Type[VllmRunner],
|
||||||
|
input_texts: List[str],
|
||||||
|
input_images: PromptImageInput,
|
||||||
|
embed_texts: List[bool],
|
||||||
|
model: str,
|
||||||
|
*,
|
||||||
|
dtype: str,
|
||||||
|
) -> None:
|
||||||
|
'''SET PYTHONPATH'''
|
||||||
|
# NOTE: take care of the order. run vLLM first, and then run HF.
|
||||||
|
# vLLM needs a fresh new process without cuda initialization.
|
||||||
|
# if we run HF first, the cuda initialization will be done and it
|
||||||
|
# will hurt multiprocessing backend with fork method (the default method).
|
||||||
|
with vllm_runner(model,
|
||||||
|
task="embedding",
|
||||||
|
dtype=dtype,
|
||||||
|
enforce_eager=True,
|
||||||
|
max_model_len=8192) as vllm_model:
|
||||||
|
tokenizer = vllm_model.model.get_tokenizer()
|
||||||
|
texts = [
|
||||||
|
# this is necessary because vllm_model.encode will not apply any
|
||||||
|
# templating to the prompt, and therefore lacks an image_pad
|
||||||
|
# token unless one is inserted beforehand (the (28,28) image
|
||||||
|
# above is converted to an image pad token by the chat template).
|
||||||
|
apply_chat_template_and_add_eos(
|
||||||
|
get_messages(image, text, False),
|
||||||
|
apply_chat_template_fn=tokenizer.apply_chat_template,
|
||||||
|
) for text, image in zip(input_texts, input_images)
|
||||||
|
# vllm will replace the pad token with the actual image,
|
||||||
|
# which may be a placeholder image, later.
|
||||||
|
]
|
||||||
|
vllm_outputs = vllm_model.encode(texts, images=input_images)
|
||||||
|
|
||||||
|
hf_outputs = []
|
||||||
|
with hf_runner(model,
|
||||||
|
dtype=dtype,
|
||||||
|
auto_cls=Qwen2VLForConditionalGeneration) as hf_model:
|
||||||
|
hf_model.postprocess_inputs = partial(
|
||||||
|
postprocess_inputs,
|
||||||
|
hf_model,
|
||||||
|
cache_position=torch.arange(
|
||||||
|
0,
|
||||||
|
1, # 1 for batch size
|
||||||
|
requires_grad=False),
|
||||||
|
use_cache=False)
|
||||||
|
for text, image, embed_text in zip(input_texts, input_images,
|
||||||
|
embed_texts):
|
||||||
|
# dse requires non-standard input processing
|
||||||
|
# because it needs an image_pad token
|
||||||
|
messages = get_messages(image, text, embed_text)
|
||||||
|
prompt = apply_chat_template_and_add_eos(
|
||||||
|
messages, hf_model.processor.apply_chat_template)
|
||||||
|
inputs = hf_model.get_inputs(
|
||||||
|
prompts=[[prompt]],
|
||||||
|
images=[[image]],
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = hf_model.model(
|
||||||
|
**hf_model.wrap_device(inputs[0],
|
||||||
|
device=hf_model.model.device.type),
|
||||||
|
return_dict=True,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)
|
||||||
|
pooled_output = torch.nn.functional.normalize(
|
||||||
|
outputs.hidden_states[-1][0, -1], p=2, dim=-1)
|
||||||
|
hf_outputs.append(pooled_output.tolist())
|
||||||
|
|
||||||
|
check_embeddings_close(
|
||||||
|
embeddings_0_lst=hf_outputs,
|
||||||
|
embeddings_1_lst=vllm_outputs,
|
||||||
|
name_0="hf",
|
||||||
|
name_1="vllm",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||||
|
def test_models_text(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
image_assets,
|
||||||
|
model: str,
|
||||||
|
dtype: str,
|
||||||
|
) -> None:
|
||||||
|
input_texts_images = [(text, image_placeholder)
|
||||||
|
for text, image_placeholder in HF_TEXT_PROMPTS]
|
||||||
|
input_texts = [text for text, _ in input_texts_images]
|
||||||
|
input_images = [image for _, image in input_texts_images]
|
||||||
|
embed_texts = [True] * len(input_texts)
|
||||||
|
|
||||||
|
_run_test(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
input_texts,
|
||||||
|
input_images, # type: ignore
|
||||||
|
embed_texts,
|
||||||
|
model,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@large_gpu_test(min_gb=48)
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||||
|
def test_models_image(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
image_assets,
|
||||||
|
model: str,
|
||||||
|
dtype: str,
|
||||||
|
) -> None:
|
||||||
|
input_texts_images = [
|
||||||
|
(text, asset.pil_image)
|
||||||
|
for text, asset in zip(HF_IMAGE_PROMPTS, image_assets)
|
||||||
|
]
|
||||||
|
input_texts = [text for text, _ in input_texts_images]
|
||||||
|
input_images = [image for _, image in input_texts_images]
|
||||||
|
embed_texts = [False] * len(input_texts)
|
||||||
|
|
||||||
|
_run_test(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
input_texts,
|
||||||
|
input_images,
|
||||||
|
embed_texts,
|
||||||
|
model,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
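Note that the Hugging Face reference path in this test pools each embedding as the L2-normalized hidden state of the final token; the ``PoolingType.LAST`` pooler added to ``Qwen2VLForConditionalGeneration`` below reproduces that behaviour on the vLLM side.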
@@ -51,6 +51,7 @@ from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import (GPTQConfig,
                                                     GPTQMarlinConfig,
                                                     QuantizationConfig)

@@ -58,12 +59,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                             MultiModalKwargs)
from vllm.multimodal.base import MultiModalData
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import cached_get_processor

@@ -1067,6 +1069,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        pooler_config = vllm_config.model_config.pooler_config
        multimodal_config = vllm_config.model_config.multimodal_config
        assert not cache_config.enable_prefix_caching, \
            "Qwen2-VL currently does not support prefix caching"

@@ -1098,6 +1101,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.LAST,
            normalize=True,
            softmax=False)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

@@ -1318,6 +1326,13 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
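A minimal, illustrative sketch (not vLLM API) of what the newly configured pooler computes for one sequence: take the final-layer hidden state of the last token (the EOS appended by the chat template) and L2-normalize it, matching the Hugging Face reference in the test above. The tensor shapes below are made up.

# Illustrative sketch of last-token pooling with normalization
# (PoolingType.LAST, normalize=True); shapes are arbitrary.
import torch
import torch.nn.functional as F

seq_len, hidden_size = 7, 1536
hidden_states = torch.randn(seq_len, hidden_size)  # final-layer hidden states of one sequence

embedding = F.normalize(hidden_states[-1], p=2, dim=-1)  # pool the last token, unit L2 norm
print(embedding.shape, embedding.norm())  # torch.Size([1536]), norm ~ 1.0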
@@ -109,6 +109,7 @@ _EMBEDDING_MODELS = {
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
}

_MULTIMODAL_MODELS = {
Loading…
x
Reference in New Issue
Block a user