[Model] Replace embedding models with pooling adapter (#10769)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in: parent 7e4bbda573 → commit 133707123e
@@ -334,7 +334,6 @@ steps:
  commands:
    - pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
    - pytest -v -s models/embedding/language -m core_model
    - pytest -v -s models/embedding/vision_language -m core_model

- label: Language Models Test (Extended) # 50min
  optional: true
@@ -346,7 +345,6 @@ steps:
  commands:
    - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
    - pytest -v -s models/embedding/language -m 'not core_model'
    - pytest -v -s models/embedding/vision_language -m 'not core_model'

- label: Multi-Modal Models Test (Standard) # 26min
  #mirror_hardwares: [amd]
@@ -359,6 +357,7 @@ steps:
  commands:
    - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
    - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
    - pytest -v -s models/embedding/vision_language -m core_model
    - pytest -v -s models/encoder_decoder/language -m core_model
    - pytest -v -s models/encoder_decoder/vision_language -m core_model

@@ -376,6 +375,7 @@ steps:
    # https://github.com/huggingface/transformers/issues/34307
    - pytest -v -s models/decoder_only/vision_language/test_phi3v.py
    - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
    - pytest -v -s models/embedding/vision_language -m 'not core_model'
    - pytest -v -s models/encoder_decoder/language -m 'not core_model'
    - pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'
@@ -357,7 +357,7 @@ Text Embedding
     - ✅︎
   * - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
     - Qwen2-based
     - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
     - :code:`ssmits/Qwen2-7B-Instruct-embed-base` (see note), :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
     - ✅︎
     - ✅︎
   * - :code:`RobertaModel`, :code:`RobertaForMaskedLM`
@@ -378,6 +378,10 @@ Text Embedding
.. tip::
  You can override the model's pooling method by passing :code:`--override-pooler-config`.

.. note::
  :code:`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
  You should manually set mean pooling by passing :code:`--override-pooler-config '{"pooling_type": "MEAN"}'`.

.. note::
  Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention.
  You can set :code:`--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly.
@@ -397,12 +401,21 @@ Reward Modeling
     - Example HF Models
     - :ref:`LoRA <lora>`
     - :ref:`PP <distributed_serving>`
   * - :code:`LlamaForCausalLM`
     - Llama-based
     - :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc.
     - ✅︎
     - ✅︎
   * - :code:`Qwen2ForRewardModel`
     - Qwen2-based
     - :code:`Qwen/Qwen2.5-Math-RM-72B`, etc.
     - ✅︎
     - ✅︎

.. important::
  For process-supervised reward models such as :code:`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
  e.g.: :code:`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.

.. note::
  As an interim measure, these models are supported in both offline and online inference via Embeddings API.
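The `--override-pooler-config` flag documented above also has an offline counterpart on the `LLM` constructor. Below is a minimal sketch of overriding the pooling method for one of the models in the table; the prompt is a placeholder and the exact shape of the returned output objects may differ slightly across vLLM versions.

```python
from vllm import LLM
from vllm.config import PoolerConfig

# Force MEAN pooling for a model whose Sentence Transformers config is unreliable
# (mirrors --override-pooler-config '{"pooling_type": "MEAN"}' from the note above).
llm = LLM(
    model="ssmits/Qwen2-7B-Instruct-embed-base",
    task="embedding",
    override_pooler_config=PoolerConfig(pooling_type="MEAN"),
)

outputs = llm.encode(["Hello, my name is"])
print(len(outputs[0].outputs.embedding))  # embedding dimension
```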
@@ -263,7 +263,6 @@ class HfRunner:
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_embedding_model: bool = False,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
@@ -4,6 +4,8 @@ Run `pytest tests/models/embedding/language/test_embedding.py`.
"""
import pytest

from vllm.config import PoolerConfig

from ..utils import check_embeddings_close


@@ -33,6 +35,9 @@ def test_models(
    dtype: str,
) -> None:
    vllm_extra_kwargs = {}
    if model == "ssmits/Qwen2-7B-Instruct-embed-base":
        vllm_extra_kwargs["override_pooler_config"] = \
            PoolerConfig(pooling_type="MEAN")
    if model == "Alibaba-NLP/gte-Qwen2-7B-instruct":
        vllm_extra_kwargs["hf_overrides"] = {"is_causal": False}
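The `check_embeddings_close` helper imported above compares vLLM embeddings against Hugging Face embeddings. As a rough sketch of the kind of check it performs (this is an illustrative stand-in with assumed names, not the actual helper), a cosine-similarity comparison looks like:

```python
import torch
import torch.nn.functional as F


def assert_embeddings_close(vllm_vecs: list[list[float]],
                            hf_vecs: list[list[float]],
                            tol: float = 1e-2) -> None:
    # Compare each pair of embeddings by cosine similarity;
    # values close to 1.0 mean the two backends agree.
    for v, h in zip(vllm_vecs, hf_vecs):
        sim = F.cosine_similarity(torch.tensor(v), torch.tensor(h), dim=0)
        assert sim >= 1.0 - tol, f"cosine similarity too low: {sim:.4f}"
```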
@ -6,11 +6,8 @@ import torch.cuda
|
||||
from vllm.model_executor.models import (is_embedding_model,
|
||||
is_text_generation_model,
|
||||
supports_multimodal)
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.model_executor.models.registry import (_CROSS_ENCODER_MODELS,
|
||||
_EMBEDDING_MODELS,
|
||||
_MULTIMODAL_MODELS,
|
||||
from vllm.model_executor.models.adapters import as_embedding_model
|
||||
from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
|
||||
_SPECULATIVE_DECODING_MODELS,
|
||||
_TEXT_GENERATION_MODELS,
|
||||
ModelRegistry)
|
||||
@ -26,18 +23,18 @@ def test_registry_imports(model_arch):
|
||||
model_cls, _ = ModelRegistry.resolve_model_cls(model_arch)
|
||||
|
||||
if model_arch in _SPECULATIVE_DECODING_MODELS:
|
||||
pass # Ignore these models which do not have a unified format
|
||||
else:
|
||||
assert is_text_generation_model(model_cls) is (
|
||||
model_arch in _TEXT_GENERATION_MODELS
|
||||
or model_arch in _MULTIMODAL_MODELS)
|
||||
return # Ignore these models which do not have a unified format
|
||||
|
||||
embedding_models = {**_EMBEDDING_MODELS, **_CROSS_ENCODER_MODELS}
|
||||
assert is_embedding_model(model_cls) is (model_arch
|
||||
in embedding_models)
|
||||
if (model_arch in _TEXT_GENERATION_MODELS
|
||||
or model_arch in _MULTIMODAL_MODELS):
|
||||
assert is_text_generation_model(model_cls)
|
||||
|
||||
assert supports_multimodal(model_cls) is (model_arch
|
||||
in _MULTIMODAL_MODELS)
|
||||
# All vLLM models should be convertible to an embedding model
|
||||
embed_model = as_embedding_model(model_cls)
|
||||
assert is_embedding_model(embed_model)
|
||||
|
||||
if model_arch in _MULTIMODAL_MODELS:
|
||||
assert supports_multimodal(model_cls)
|
||||
|
||||
|
||||
@fork_new_process_for_each_test
|
||||
|
@ -1,13 +1,34 @@
|
||||
from typing import List, Optional, Union
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.model_executor.models.gemma2 import Gemma2EmbeddingModel
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.models.gemma2 import Gemma2Model
|
||||
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
|
||||
class MyGemma2Embedding(Gemma2EmbeddingModel):
|
||||
class MyGemma2Embedding(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
self.model = Gemma2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
vllm_config.model_config.pooler_config,
|
||||
pooling_type=PoolingType.LAST,
|
||||
normalize=True,
|
||||
softmax=False,
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -18,7 +39,7 @@ class MyGemma2Embedding(Gemma2EmbeddingModel):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = super().forward(
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
@ -32,3 +53,17 @@ class MyGemma2Embedding(Gemma2EmbeddingModel):
|
||||
|
||||
# Return all-zero embeddings
|
||||
return torch.zeros_like(hidden_states)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
|
||||
weights = hf_to_vllm_mapper.apply(weights)
|
||||
weights = ((name, data) for name, data in weights
|
||||
if not name.startswith("lm_head."))
|
||||
return self.model.load_weights(weights)
|
||||
|
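The diff above rewrites the out-of-tree `MyGemma2Embedding` example to compose `Gemma2Model` with an explicit `Pooler` instead of subclassing the removed `Gemma2EmbeddingModel`. For context, such an out-of-tree class is wired up through the model registry; a minimal sketch, where the plugin module path is a placeholder:

```python
from vllm import ModelRegistry


def register() -> None:
    # Point the architecture name at the out-of-tree class. Passing a
    # "module:Class" string keeps the import lazy until the model is used.
    ModelRegistry.register_model(
        "MyGemma2Embedding",
        "my_plugin.my_gemma2_embedding:MyGemma2Embedding",  # placeholder path
    )
```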
@@ -26,8 +26,7 @@ def test_auto_task(model_id, expected_task):


@pytest.mark.parametrize(("model_id", "bad_task"), [
    ("facebook/opt-125m", "embedding"),
    ("intfloat/e5-mistral-7b-instruct", "generate"),
    ("Qwen/Qwen2.5-Math-RM-72B", "generate"),
])
def test_incorrect_task(model_id, bad_task):
    with pytest.raises(ValueError, match=r"does not support the .* task"):
@@ -370,6 +370,31 @@ class ModelConfig:
        selected_task = next(iter(supported_tasks_lst))

        if len(supported_tasks) > 1:
            suffix_to_preferred_task: List[Tuple[str, _Task]] = [
                # Hardcode the models that are exceptions
                ("AquilaModel", "generate"),
                ("ChatGLMModel", "generate"),
                # Other models follow this pattern
                ("ForCausalLM", "generate"),
                ("ForConditionalGeneration", "generate"),
                ("ChatModel", "generate"),
                ("LMHeadModel", "generate"),
                ("EmbeddingModel", "embedding"),
                ("RewardModel", "embedding"),
                ("ForSequenceClassification", "embedding"),
            ]
            info, arch = ModelRegistry.inspect_model_cls(architectures)

            for suffix, pref_task in suffix_to_preferred_task:
                if arch.endswith(suffix) and pref_task in supported_tasks:
                    selected_task = pref_task
                    break
            else:
                if (arch.endswith("Model")
                        and info.architecture.endswith("ForCausalLM")
                        and "embedding" in supported_tasks):
                    selected_task = "embedding"

            logger.info(
                "This model supports multiple tasks: %s. "
                "Defaulting to '%s'.", supported_tasks, selected_task)
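To make the suffix-preference logic above concrete, here is a small standalone sketch of how an architecture name maps to a default task under these rules. It is simplified: it ignores the `supported_tasks` filtering and the fallback for `*Model` architectures backed by a `*ForCausalLM` class.

```python
SUFFIX_TO_TASK = [
    ("AquilaModel", "generate"),
    ("ChatGLMModel", "generate"),
    ("ForCausalLM", "generate"),
    ("ForConditionalGeneration", "generate"),
    ("ChatModel", "generate"),
    ("LMHeadModel", "generate"),
    ("EmbeddingModel", "embedding"),
    ("RewardModel", "embedding"),
    ("ForSequenceClassification", "embedding"),
]


def default_task(arch: str) -> str:
    # First matching suffix wins; in this sketch, anything unmatched
    # falls back to "generate".
    for suffix, task in SUFFIX_TO_TASK:
        if arch.endswith(suffix):
            return task
    return "generate"


assert default_task("Qwen2ForCausalLM") == "generate"
assert default_task("Qwen2ForRewardModel") == "embedding"
```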
@ -11,8 +11,8 @@ from typing_extensions import TypeVar, assert_never
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
|
||||
resolve_mm_processor_kwargs)
|
||||
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
|
||||
print_warning_once, resolve_mm_processor_kwargs)
|
||||
|
||||
from .data import ProcessorInputs, SingletonInputs
|
||||
from .parse import is_encoder_decoder_inputs
|
||||
@ -136,12 +136,12 @@ class InputRegistry:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._dummy_factories_by_model_type: Dict[Type[nn.Module],
|
||||
DummyDataFactory] = {}
|
||||
self._dummy_encoder_factories_by_model_type: Dict[
|
||||
Type[nn.Module], DummyDataFactory] = {}
|
||||
self._input_processors_by_model_type: Dict[Type[nn.Module],
|
||||
InputProcessor] = {}
|
||||
self._dummy_factories_by_model_type = \
|
||||
ClassRegistry[nn.Module, DummyDataFactory]()
|
||||
self._dummy_encoder_factories_by_model_type = \
|
||||
ClassRegistry[nn.Module, DummyDataFactory]()
|
||||
self._input_processors_by_model_type = \
|
||||
ClassRegistry[nn.Module, InputProcessor]()
|
||||
|
||||
def _default_dummy_data_factory(
|
||||
self,
|
||||
|
@@ -60,9 +60,7 @@ class Pooler(nn.Module):
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[List[int]] = None,
    ) -> Optional["Pooler"]:
        if pooler_config is None:
            return None
    ) -> "Pooler":
        return cls(
            pooling_type=PoolingType[pooler_config.pooling_type]
            if pooler_config.pooling_type is not None else pooling_type,
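After this change, `Pooler.from_config_with_defaults` always receives a `PoolerConfig` and always returns a `Pooler`; the `Optional` handling moves to the callers. A short usage sketch with the same defaults the embedding adapter in this commit applies:

```python
from vllm.config import PoolerConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType

# An empty PoolerConfig means "use the caller's defaults"; any field the user
# sets (e.g. via --override-pooler-config) takes precedence over them.
pooler = Pooler.from_config_with_defaults(
    PoolerConfig(),
    pooling_type=PoolingType.LAST,
    normalize=True,
    softmax=False,
)
```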
@ -9,6 +9,7 @@ import itertools
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
|
||||
@ -97,22 +98,31 @@ def device_loading_context(module: torch.nn.Module,
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
|
||||
def _initialize_model(
|
||||
vllm_config: VllmConfig,
|
||||
*,
|
||||
prefix: str = "",
|
||||
architectures: Optional[list[str]] = None,
|
||||
) -> nn.Module:
|
||||
"""Initialize a model with the given configurations."""
|
||||
model_config = vllm_config.model_config
|
||||
model_class, _ = get_model_architecture(model_config)
|
||||
model_class, _ = get_model_architecture(model_config,
|
||||
architectures=architectures)
|
||||
|
||||
signatures = inspect.signature(model_class.__init__)
|
||||
all_params = [param.name for param in signatures.parameters.values()]
|
||||
if "vllm_config" in all_params and "prefix" in all_params:
|
||||
# new-style model class
|
||||
with set_current_vllm_config(vllm_config):
|
||||
return model_class(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
|
||||
"input arguments. Possibly you have an old-style model class"
|
||||
" registered from out of tree and it is used for new vLLM version. "
|
||||
"Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
|
||||
"for the design and update the model class accordingly.")
|
||||
logger.warning(msg)
|
||||
warnings.warn(msg, DeprecationWarning, stacklevel=2)
|
||||
|
||||
logger.warning(
|
||||
"Trying to guess the arguments for old-style model class %s",
|
||||
model_class,
|
||||
@ -356,7 +366,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
weights_to_load = {name for name, _ in model.named_parameters()}
|
||||
loaded_weights = model.load_weights(
|
||||
self._get_all_weights(model_config, model))
|
||||
# We only enable strict check for non-quantiized models
|
||||
# We only enable strict check for non-quantized models
|
||||
# that have loaded weights tracking currently.
|
||||
if model_config.quantization is None and loaded_weights is not None:
|
||||
weights_not_loaded = weights_to_load - loaded_weights
|
||||
|
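The `_initialize_model` hunk above distinguishes new-style model classes, which accept `vllm_config` and `prefix`, from old-style ones that now trigger a `DeprecationWarning`. A minimal sketch of that signature check, using only the standard library:

```python
import inspect


class NewStyleModel:
    def __init__(self, *, vllm_config, prefix: str = ""):
        ...


params = inspect.signature(NewStyleModel.__init__).parameters
is_new_style = "vllm_config" in params and "prefix" in params
assert is_new_style  # old-style classes fall through to the warning path
```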
@@ -1,12 +1,13 @@
"""Utilities for selecting and loading models."""
import contextlib
from typing import Tuple, Type
from typing import Optional, Tuple, Type

import torch
from torch import nn

from vllm.config import ModelConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import as_embedding_model


@contextlib.contextmanager
@@ -19,8 +20,13 @@ def set_default_torch_dtype(dtype: torch.dtype):


def get_model_architecture(
        model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
    model_config: ModelConfig,
    *,
    architectures: Optional[list[str]] = None,
) -> Tuple[Type[nn.Module], str]:
    if architectures is None:
        architectures = getattr(model_config.hf_config, "architectures", [])

    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    mixtral_supported = [
@@ -32,7 +38,11 @@ def get_model_architecture(
            and "MixtralForCausalLM" in architectures):
        architectures = ["QuantMixtralForCausalLM"]

    return ModelRegistry.resolve_model_cls(architectures)
    model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
    if model_config.task == "embedding":
        model_cls = as_embedding_model(model_cls)

    return model_cls, arch


def get_architecture_class_name(model_config: ModelConfig) -> str:
vllm/model_executor/models/adapters.py (new file, 98 lines)
@@ -0,0 +1,98 @@
from collections.abc import Iterable
from typing import Any, TypeVar

import torch
import torch.nn as nn

from .interfaces_base import VllmModelForEmbedding, is_embedding_model

_T = TypeVar("_T", bound=type[nn.Module])


def as_embedding_model(cls: _T) -> _T:
    """Subclass an existing vLLM model to support embeddings."""
    # Avoid modifying existing embedding models
    if is_embedding_model(cls):
        return cls

    # Lazy import
    from vllm.config import VllmConfig
    from vllm.model_executor.layers.pooler import (Pooler, PoolerOutput,
                                                   PoolingType)
    from vllm.model_executor.pooling_metadata import PoolingMetadata

    from .utils import AutoWeightsLoader, WeightsMapper

    class ModelForEmbedding(cls, VllmModelForEmbedding):

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            # These are not used in embedding models
            for attr in ("lm_head", "logits_processor"):
                if hasattr(self, attr):
                    delattr(self, attr)

            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None

            # If the model already defines a pooler instance, don't overwrite it
            if not getattr(self, "_pooler", None):
                self._pooler = Pooler.from_config_with_defaults(
                    pooler_config,
                    pooling_type=PoolingType.LAST,
                    normalize=True,
                    softmax=False,
                )

        def pooler(
            self,
            hidden_states: torch.Tensor,
            pooling_metadata: PoolingMetadata,
        ) -> PoolerOutput:
            return self._pooler(hidden_states, pooling_metadata)

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            # TODO: Support uninitialized params tracking

            # We have deleted this attribute, so don't load it
            weights = ((name, data) for name, data in weights
                       if not name.startswith("lm_head."))

            # If `*ForCausalLM` defines `load_weights` on the inner model
            # and there are no other inner modules with parameters,
            # we support loading from both `*Model` and `*ForCausalLM`
            if hasattr(self, "model") and hasattr(self.model, "load_weights"):
                # Whether only `self.model` contains parameters
                model_is_only_param = all(
                    name == "model" or next(child.parameters(), None) is None
                    for name, child in self.named_children())

                if model_is_only_param:
                    mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
                    weights = mapper.apply(weights)

                    self.model.load_weights(weights)
                    return

            # For most other models
            if hasattr(cls, "load_weights"):
                cls.load_weights(self, weights)  # type: ignore
            # Fallback
            else:
                loader = AutoWeightsLoader(self)
                loader.load_weights(weights)

    ModelForEmbedding.__name__ = cls.__name__ \
        .removesuffix("ForCausalLM") \
        .removesuffix("ForConditionalGeneration") \
        .removesuffix("ChatModel") \
        .removesuffix("LMHeadModel") + "ForEmbedding"

    return ModelForEmbedding  # type: ignore
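A short usage sketch of the new adapter: converting an existing generative model class into an embedding model class. The class-level assertions below follow from `as_embedding_model` as defined above; instantiating the result still requires a full `VllmConfig`, so only class attributes are checked here.

```python
from vllm.model_executor.models import is_embedding_model
from vllm.model_executor.models.adapters import as_embedding_model
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM

Qwen2ForEmbedding = as_embedding_model(Qwen2ForCausalLM)

assert is_embedding_model(Qwen2ForEmbedding)
assert Qwen2ForEmbedding.__name__ == "Qwen2ForEmbedding"
# Converting a class that already supports embeddings is a no-op:
assert as_embedding_model(Qwen2ForEmbedding) is Qwen2ForEmbedding
```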
@ -512,9 +512,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
)
|
||||
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"))
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
@ -30,19 +30,17 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
|
||||
from .utils import (AutoWeightsLoader, extract_layer_index,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
@ -455,55 +453,3 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
class Gemma2EmbeddingModel(nn.Module, SupportsPP):
|
||||
"""
|
||||
A model that uses Gemma2 with additional embedding functionalities.
|
||||
|
||||
This class encapsulates the Gemma2Model and provides an interface for
|
||||
embedding operations and customized pooling functions.
|
||||
|
||||
Attributes:
|
||||
model: An instance of Gemma2Model used for forward operations.
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
self.model = Gemma2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
vllm_config.model_config.pooler_config,
|
||||
pooling_type=PoolingType.LAST,
|
||||
normalize=True,
|
||||
softmax=False)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
return self.model(input_ids, positions, kv_caches, attn_metadata,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
|
||||
weights = hf_to_vllm_mapper.apply(weights)
|
||||
weights = ((name, data) for name, data in weights
|
||||
if not name.startswith("lm_head."))
|
||||
self.model.load_weights(weights)
|
||||
|
@ -474,9 +474,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
)
|
||||
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"))
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
self.mlp1 = self._init_mlp1(config)
|
||||
|
||||
|
@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
get_compressed_tensors_cache_scale)
|
||||
@ -47,14 +46,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
|
||||
extract_layer_index, is_pp_missing_parameter,
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
@ -511,11 +509,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.model = self._init_model(vllm_config=vllm_config, prefix=prefix)
|
||||
self.model = self._init_model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
@ -544,13 +543,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.sampler = get_sampler()
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.STEP,
|
||||
normalize=False,
|
||||
softmax=False)
|
||||
|
||||
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
return LlamaModel(vllm_config=vllm_config, prefix=prefix)
|
||||
@ -581,14 +576,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
logits = self.compute_logits(hidden_states, None)
|
||||
return self._pooler(logits, pooling_metadata)
|
||||
|
||||
def sample(self, logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
@ -639,78 +626,3 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
name = name.replace(item, mapping[item])
|
||||
|
||||
return name, loaded_weight
|
||||
|
||||
|
||||
class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
"""
|
||||
A model that uses Llama with additional embedding functionalities.
|
||||
|
||||
This class encapsulates the LlamaModel and provides an interface for
|
||||
embedding operations and customized pooling functions.
|
||||
|
||||
Attributes:
|
||||
model: An instance of LlamaModel used for forward operations.
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
}
|
||||
|
||||
# LoRA specific attributes
|
||||
supported_lora_modules = [
|
||||
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
|
||||
]
|
||||
embedding_modules = {
|
||||
"embed_tokens": "input_embeddings",
|
||||
}
|
||||
embedding_padding_modules = []
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
|
||||
self.model = LlamaModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.LAST,
|
||||
normalize=True,
|
||||
softmax=False)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
return self.model(input_ids, positions, kv_caches, attn_metadata,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
|
||||
weights = hf_to_vllm_mapper.apply(weights)
|
||||
weights = ((name, data) for name, data in weights
|
||||
if not name.startswith("lm_head."))
|
||||
self.model.load_weights(weights)
|
||||
|
||||
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
||||
self.model.load_kv_cache_scales(quantization_param_path)
|
||||
|
||||
# LRUCacheWorkerLoRAManager instantiation requires model config.
|
||||
@property
|
||||
def config(self):
|
||||
return self.model.config
|
||||
|
@ -319,9 +319,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
projector_hidden_act=config.projector_hidden_act)
|
||||
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"))
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
@ -14,13 +14,11 @@ from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext)
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .clip import (CLIPVisionModel, dummy_image_for_clip,
|
||||
@ -286,7 +284,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
|
||||
vision_feature_layer = config.vision_feature_layer
|
||||
@ -321,17 +318,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
projector_hidden_act=config.projector_hidden_act)
|
||||
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"))
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
# The same model class supports both language generation and embedding
|
||||
# because the architecture name is the same
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.LAST,
|
||||
normalize=True,
|
||||
softmax=False)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
||||
@ -678,13 +669,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
) -> Optional[SamplerOutput]:
|
||||
return self.language_model.sample(logits, sampling_metadata)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
|
@ -275,9 +275,10 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
text_hidden_size=config.text_config.hidden_size,
|
||||
projector_hidden_act=config.projector_hidden_act)
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"))
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.model.make_empty_intermediate_tensors)
|
||||
|
@ -422,9 +422,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
prefix=maybe_prefix(prefix, "vision_tower"))
|
||||
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"))
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
self.image_newline = nn.Parameter(
|
||||
torch.empty(config.text_config.hidden_size))
|
||||
|
||||
|
@ -151,9 +151,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.quant_config = quant_config
|
||||
config.text_config.architectures = ["GemmaForCausalLM"]
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"))
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.language_model.logits_processor.scale *= logit_scale
|
||||
|
||||
|
@ -29,24 +29,22 @@ from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.clip import CLIPVisionModel
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
|
||||
from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -536,7 +534,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
@ -556,18 +553,17 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
quant_config,
|
||||
prefix=maybe_prefix(prefix, "model.vision_embed_tokens"))
|
||||
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
# The prefix is empty intentionally because default prefix of
|
||||
# LlamaForCausalLM is "model"
|
||||
self.language_model = LlamaForCausalLM(vllm_config=vllm_config,
|
||||
prefix="")
|
||||
prefix="",
|
||||
# We don't directly initialize vLLM's LlamaForCausalLM so we
|
||||
# can automatically apply embedding wrapper if this model is
|
||||
# initialized as an embedding model
|
||||
architectures=["LlamaForCausalLM"],
|
||||
)
|
||||
|
||||
# The same model class supports both language generation and embedding
|
||||
# because the architecture name is the same
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.LAST,
|
||||
normalize=True,
|
||||
softmax=False)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
||||
@ -739,13 +735,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
) -> Optional[SamplerOutput]:
|
||||
return self.language_model.sample(logits, sampling_metadata)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
|
@ -172,9 +172,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
# init MistralForCausalLM
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"))
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
self.vision_encoder = VisionTransformer(self.vision_args)
|
||||
self.vision_language_adapter = VisionLanguageAdapter(
|
||||
|
@ -31,6 +31,7 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
@ -55,6 +56,8 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Qwen2MLP(nn.Module):
|
||||
|
||||
@ -433,7 +436,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
@ -454,14 +456,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = get_sampler()
|
||||
|
||||
# The same model class supports both language generation and embedding
|
||||
# because the architecture name is the same
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.LAST,
|
||||
normalize=True,
|
||||
softmax=False)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
@ -499,13 +493,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
@ -553,6 +540,15 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.model = Qwen2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
# TODO: Replace this model class with for_embedding(Qwen2ForCausalLM),
|
||||
# after changing the default pooling method
|
||||
if pooler_config.pooling_type is None:
|
||||
logger.warning(
|
||||
"This embedding model will default to last-token pooling in "
|
||||
"an upcoming version. To avoid breaking changes, you should "
|
||||
"pass `--override-pooler-config '{\"pooling_type\": \"MEAN\"}'`"
|
||||
" explicitly.")
|
||||
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.MEAN,
|
||||
|
@ -50,7 +50,6 @@ from vllm.model_executor.layers.activation import QuickGELU
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
@ -59,14 +58,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.qwen2 import Qwen2Model
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import cached_get_image_processor
|
||||
from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
|
||||
MultiModalKwargs, NestedTensors)
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.platforms import _Backend
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
from vllm.transformers_utils.config import uses_mrope
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
|
||||
@ -1070,7 +1068,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
assert not cache_config.enable_prefix_caching, \
|
||||
"Qwen2-VL currently does not support prefix caching"
|
||||
@ -1102,11 +1099,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = get_sampler()
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.LAST,
|
||||
normalize=True,
|
||||
softmax=False)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
@ -1361,13 +1354,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
|
@ -20,6 +20,7 @@ import torch.nn as nn
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .adapters import as_embedding_model
|
||||
from .interfaces import (has_inner_state, is_attention_free,
|
||||
supports_cross_encoding, supports_multimodal,
|
||||
supports_pp)
|
||||
@ -107,15 +108,15 @@ _EMBEDDING_MODELS = {
|
||||
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
|
||||
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
|
||||
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
|
||||
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
|
||||
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
||||
"LlamaModel": ("llama", "LlamaEmbeddingModel"),
|
||||
"LlamaModel": ("llama", "LlamaForCausalLM"),
|
||||
**{
|
||||
# Multiple models share the same architecture, so we include them all
|
||||
k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
|
||||
if arch == "LlamaForCausalLM"
|
||||
},
|
||||
"MistralModel": ("llama", "LlamaEmbeddingModel"),
|
||||
"MistralModel": ("llama", "LlamaForCausalLM"),
|
||||
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
|
||||
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
|
||||
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
||||
@ -125,7 +126,7 @@ _EMBEDDING_MODELS = {
|
||||
# [Multimodal]
|
||||
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
|
||||
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
|
||||
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501,
|
||||
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
|
||||
}
|
||||
|
||||
_CROSS_ENCODER_MODELS = {
|
||||
@ -208,6 +209,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ModelInfo:
|
||||
architecture: str
|
||||
is_text_generation_model: bool
|
||||
is_embedding_model: bool
|
||||
supports_cross_encoding: bool
|
||||
@ -218,9 +220,19 @@ class _ModelInfo:
|
||||
|
||||
@staticmethod
|
||||
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
|
||||
is_embedding_model_ = is_embedding_model(model)
|
||||
if not is_embedding_model_:
|
||||
try:
|
||||
as_embedding_model(model)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
is_embedding_model_ = True
|
||||
|
||||
return _ModelInfo(
|
||||
architecture=model.__name__,
|
||||
is_text_generation_model=is_text_generation_model(model),
|
||||
is_embedding_model=is_embedding_model(model),
|
||||
is_embedding_model=is_embedding_model_,
|
||||
supports_cross_encoding=supports_cross_encoding(model),
|
||||
supports_multimodal=supports_multimodal(model),
|
||||
supports_pp=supports_pp(model),
|
||||
@ -399,13 +411,13 @@ class _ModelRegistry:
|
||||
def inspect_model_cls(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> _ModelInfo:
|
||||
) -> Tuple[_ModelInfo, str]:
|
||||
architectures = self._normalize_archs(architectures)
|
||||
|
||||
for arch in architectures:
|
||||
model_info = self._try_inspect_model_cls(arch)
|
||||
if model_info is not None:
|
||||
return model_info
|
||||
return (model_info, arch)
|
||||
|
||||
return self._raise_for_unsupported(architectures)
|
||||
|
||||
@ -426,39 +438,50 @@ class _ModelRegistry:
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).is_text_generation_model
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.is_text_generation_model
|
||||
|
||||
def is_embedding_model(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).is_embedding_model
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.is_embedding_model
|
||||
|
||||
def is_cross_encoder_model(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).supports_cross_encoding
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.supports_cross_encoding
|
||||
|
||||
def is_multimodal_model(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).supports_multimodal
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.supports_multimodal
|
||||
|
||||
def is_pp_supported_model(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).supports_pp
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.supports_pp
|
||||
|
||||
def model_has_inner_state(self, architectures: Union[str,
|
||||
List[str]]) -> bool:
|
||||
return self.inspect_model_cls(architectures).has_inner_state
|
||||
def model_has_inner_state(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.has_inner_state
|
||||
|
||||
def is_attention_free_model(self, architectures: Union[str,
|
||||
List[str]]) -> bool:
|
||||
return self.inspect_model_cls(architectures).is_attention_free
|
||||
def is_attention_free_model(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.is_attention_free
|
||||
|
||||
|
||||
ModelRegistry = _ModelRegistry({
|
||||
|
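With the registry changes above, `inspect_model_cls` now returns the resolved architecture name alongside the model info, which is what the `ModelConfig` task-selection code earlier in this commit consumes. A hedged sketch of the new call shape:

```python
from vllm import ModelRegistry

info, arch = ModelRegistry.inspect_model_cls(["LlamaModel"])
# `arch` is the architecture string that matched ("LlamaModel"), while
# `info.architecture` names the class it resolved to; after this commit the
# two can differ, e.g. "LlamaModel" now resolves to "LlamaForCausalLM".
print(arch, info.architecture, info.is_embedding_model)
```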
@ -360,9 +360,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
))
|
||||
self.multi_modal_projector = UltravoxProjector(config)
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"))
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
if config.text_model_id is not None:
|
||||
# this prefix is not for initialization, but for loading weights
|
||||
# note the trailing dot
|
||||
|
@ -173,8 +173,15 @@ class AutoWeightsLoader:
|
||||
module_load_weights = getattr(module, "load_weights", None)
|
||||
if callable(module_load_weights):
|
||||
loaded_params = module_load_weights(weights)
|
||||
yield from map(lambda x: self._get_qualname(base_prefix, x),
|
||||
loaded_params)
|
||||
if loaded_params is None:
|
||||
logger.warning(
|
||||
"Unable to collect loaded parameters "
|
||||
"for module %s", module)
|
||||
else:
|
||||
yield from map(
|
||||
lambda x: self._get_qualname(base_prefix, x),
|
||||
loaded_params,
|
||||
)
|
||||
|
||||
child_modules = dict(module.named_children())
|
||||
child_params = dict(module.named_parameters(recurse=False))
|
||||
@ -232,17 +239,24 @@ class AutoWeightsLoader:
|
||||
|
||||
|
||||
def init_vllm_registered_model(
|
||||
hf_config: PretrainedConfig,
|
||||
vllm_config: VllmConfig,
|
||||
*,
|
||||
prefix: str = "",
|
||||
hf_config: Optional[PretrainedConfig] = None,
|
||||
architectures: Optional[list[str]] = None,
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Helper function to initialize an inner model registered to vLLM,
|
||||
based on the arguments passed to the outer vLLM model.
|
||||
"""
|
||||
from vllm.model_executor.model_loader.loader import _initialize_model
|
||||
|
||||
if hf_config is not None:
|
||||
vllm_config = vllm_config.with_hf_config(hf_config)
|
||||
return _initialize_model(vllm_config, prefix)
|
||||
|
||||
return _initialize_model(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
architectures=architectures)
|
||||
|
||||
|
||||
@overload
|
||||
|
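The `init_vllm_registered_model` change above makes `hf_config` an optional keyword and adds an `architectures` override, which the multimodal models in this commit use when building their inner language model. A sketch of the two call patterns, with the surrounding class elided:

```python
# Typical inner-LM initialization after this commit (sketch):
self.language_model = init_vllm_registered_model(
    vllm_config=vllm_config,
    hf_config=config.text_config,  # optional: overrides the outer HF config
    prefix=maybe_prefix(prefix, "language_model"),
)

# Phi-3-V instead pins the architecture so the embedding adapter can be
# applied automatically when the model is initialized with task="embedding":
self.language_model = init_vllm_registered_model(
    vllm_config=vllm_config,
    prefix="",
    architectures=["LlamaForCausalLM"],
)
```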
@ -7,7 +7,7 @@ from torch import nn
|
||||
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (get_allowed_kwarg_only_overrides,
|
||||
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
|
||||
resolve_mm_processor_kwargs)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -54,8 +54,8 @@ class MultiModalPlugin(ABC):
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
|
||||
self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}
|
||||
self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]()
|
||||
self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]()
|
||||
|
||||
@abstractmethod
|
||||
def get_data_key(self) -> str:
|
||||
|
@ -9,6 +9,7 @@ from typing_extensions import TypeAlias
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import ClassRegistry
|
||||
|
||||
from .audio import AudioPlugin
|
||||
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
|
||||
@ -62,8 +63,8 @@ class MultiModalRegistry:
|
||||
plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
|
||||
self._plugins = {p.get_data_key(): p for p in plugins}
|
||||
|
||||
self._processor_factories: Dict[Type[nn.Module],
|
||||
MultiModalProcessorFactory] = {}
|
||||
self._processor_factories = ClassRegistry[nn.Module,
|
||||
MultiModalProcessorFactory]()
|
||||
|
||||
# This is used for non-multimodal models
|
||||
self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
|
||||
|
@@ -20,7 +20,7 @@ import uuid
import warnings
import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections import defaultdict
from collections import UserDict, defaultdict
from collections.abc import Iterable, Mapping
from functools import lru_cache, partial, wraps
from platform import uname
@@ -1517,13 +1517,13 @@ class AtomicCounter:


# Adapted from: https://stackoverflow.com/a/47212782/5082708
class LazyDict(Mapping, Generic[T]):
class LazyDict(Mapping[str, T], Generic[T]):

    def __init__(self, factory: Dict[str, Callable[[], T]]):
        self._factory = factory
        self._dict: Dict[str, T] = {}

    def __getitem__(self, key) -> T:
    def __getitem__(self, key: str) -> T:
        if key not in self._dict:
            if key not in self._factory:
                raise KeyError(key)
@@ -1540,6 +1540,22 @@ class LazyDict(Mapping, Generic[T]):
        return len(self._factory)


class ClassRegistry(UserDict[type[T], _V]):

    def __getitem__(self, key: type[T]) -> _V:
        for cls in key.mro():
            if cls in self.data:
                return self.data[cls]

        raise KeyError(key)

    def __contains__(self, key: object) -> bool:
        if not isinstance(key, type):
            return False

        return any(cls in self.data for cls in key.mro())


def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """
    Create a weak reference to a tensor.
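`ClassRegistry` replaces the plain dicts keyed by model class in the input and multimodal registries above; lookups walk the MRO, so an entry registered for a base class is also found for its subclasses. A small sketch, with `torch.nn` classes standing in for model classes:

```python
import torch.nn as nn

from vllm.utils import ClassRegistry

registry = ClassRegistry[nn.Module, str]()
registry[nn.Linear] = "linear handler"


class TiedLinear(nn.Linear):
    pass


# Subclasses resolve to the entry registered for their base class.
assert TiedLinear in registry
assert registry[TiedLinear] == "linear handler"
assert nn.Conv2d not in registry
```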