[Model] Replace embedding models with pooling adapter (#10769)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2024-12-01 08:02:54 +08:00 committed by GitHub
parent 7e4bbda573
commit 133707123e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 383 additions and 319 deletions

View File

@ -334,7 +334,6 @@ steps:
commands:
- pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
- pytest -v -s models/embedding/language -m core_model
- pytest -v -s models/embedding/vision_language -m core_model
- label: Language Models Test (Extended) # 50min
optional: true
@ -346,7 +345,6 @@ steps:
commands:
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/language -m 'not core_model'
- pytest -v -s models/embedding/vision_language -m 'not core_model'
- label: Multi-Modal Models Test (Standard) # 26min
#mirror_hardwares: [amd]
@ -359,6 +357,7 @@ steps:
commands:
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
- pytest -v -s models/embedding/vision_language -m core_model
- pytest -v -s models/encoder_decoder/language -m core_model
- pytest -v -s models/encoder_decoder/vision_language -m core_model
@ -376,6 +375,7 @@ steps:
# https://github.com/huggingface/transformers/issues/34307
- pytest -v -s models/decoder_only/vision_language/test_phi3v.py
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/vision_language -m 'not core_model'
- pytest -v -s models/encoder_decoder/language -m 'not core_model'
- pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'

View File

@ -357,7 +357,7 @@ Text Embedding
- ✅︎
* - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
- Qwen2-based
- :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
- :code:`ssmits/Qwen2-7B-Instruct-embed-base` (see note), :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
- ✅︎
- ✅︎
* - :code:`RobertaModel`, :code:`RobertaForMaskedLM`
@ -378,6 +378,10 @@ Text Embedding
.. tip::
You can override the model's pooling method by passing :code:`--override-pooler-config`.
.. note::
:code:`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
You should manually set mean pooling by passing :code:`--override-pooler-config '{"pooling_type": "MEAN"}'`.
.. note::
Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention.
You can set :code:`--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly.
@ -397,12 +401,21 @@ Reward Modeling
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`LlamaForCausalLM`
- Llama-based
- :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc.
- ✅︎
- ✅︎
* - :code:`Qwen2ForRewardModel`
- Qwen2-based
- :code:`Qwen/Qwen2.5-Math-RM-72B`, etc.
- ✅︎
- ✅︎
.. important::
For process-supervised reward models such as :code:`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
e.g.: :code:`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
.. note::
As an interim measure, these models are supported in both offline and online inference via Embeddings API.

View File

@ -263,7 +263,6 @@ class HfRunner:
dtype: str = "half",
*,
model_kwargs: Optional[Dict[str, Any]] = None,
is_embedding_model: bool = False,
is_sentence_transformer: bool = False,
is_cross_encoder: bool = False,
skip_tokenizer_init: bool = False,

View File

@ -4,6 +4,8 @@ Run `pytest tests/models/embedding/language/test_embedding.py`.
"""
import pytest
from vllm.config import PoolerConfig
from ..utils import check_embeddings_close
@ -33,6 +35,9 @@ def test_models(
dtype: str,
) -> None:
vllm_extra_kwargs = {}
if model == "ssmits/Qwen2-7B-Instruct-embed-base":
vllm_extra_kwargs["override_pooler_config"] = \
PoolerConfig(pooling_type="MEAN")
if model == "Alibaba-NLP/gte-Qwen2-7B-instruct":
vllm_extra_kwargs["hf_overrides"] = {"is_causal": False}

View File

@ -6,11 +6,8 @@ import torch.cuda
from vllm.model_executor.models import (is_embedding_model,
is_text_generation_model,
supports_multimodal)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.models.registry import (_CROSS_ENCODER_MODELS,
_EMBEDDING_MODELS,
_MULTIMODAL_MODELS,
from vllm.model_executor.models.adapters import as_embedding_model
from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
_SPECULATIVE_DECODING_MODELS,
_TEXT_GENERATION_MODELS,
ModelRegistry)
@ -26,18 +23,18 @@ def test_registry_imports(model_arch):
model_cls, _ = ModelRegistry.resolve_model_cls(model_arch)
if model_arch in _SPECULATIVE_DECODING_MODELS:
pass # Ignore these models which do not have a unified format
else:
assert is_text_generation_model(model_cls) is (
model_arch in _TEXT_GENERATION_MODELS
or model_arch in _MULTIMODAL_MODELS)
return # Ignore these models which do not have a unified format
embedding_models = {**_EMBEDDING_MODELS, **_CROSS_ENCODER_MODELS}
assert is_embedding_model(model_cls) is (model_arch
in embedding_models)
if (model_arch in _TEXT_GENERATION_MODELS
or model_arch in _MULTIMODAL_MODELS):
assert is_text_generation_model(model_cls)
assert supports_multimodal(model_cls) is (model_arch
in _MULTIMODAL_MODELS)
# All vLLM models should be convertible to an embedding model
embed_model = as_embedding_model(model_cls)
assert is_embedding_model(embed_model)
if model_arch in _MULTIMODAL_MODELS:
assert supports_multimodal(model_cls)
@fork_new_process_for_each_test

View File

@ -1,13 +1,34 @@
from typing import List, Optional, Union
from typing import Iterable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from vllm.attention import AttentionMetadata
from vllm.model_executor.models.gemma2 import Gemma2EmbeddingModel
from vllm.sequence import IntermediateTensors
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
class MyGemma2Embedding(Gemma2EmbeddingModel):
class MyGemma2Embedding(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.model = Gemma2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults(
vllm_config.model_config.pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False,
)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
@ -18,7 +39,7 @@ class MyGemma2Embedding(Gemma2EmbeddingModel):
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = super().forward(
hidden_states = self.model(
input_ids,
positions,
kv_caches,
@ -32,3 +53,17 @@ class MyGemma2Embedding(Gemma2EmbeddingModel):
# Return all-zero embeddings
return torch.zeros_like(hidden_states)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
return self.model.load_weights(weights)

View File

@ -26,8 +26,7 @@ def test_auto_task(model_id, expected_task):
@pytest.mark.parametrize(("model_id", "bad_task"), [
("facebook/opt-125m", "embedding"),
("intfloat/e5-mistral-7b-instruct", "generate"),
("Qwen/Qwen2.5-Math-RM-72B", "generate"),
])
def test_incorrect_task(model_id, bad_task):
with pytest.raises(ValueError, match=r"does not support the .* task"):

View File

@ -370,6 +370,31 @@ class ModelConfig:
selected_task = next(iter(supported_tasks_lst))
if len(supported_tasks) > 1:
suffix_to_preferred_task: List[Tuple[str, _Task]] = [
# Hardcode the models that are exceptions
("AquilaModel", "generate"),
("ChatGLMModel", "generate"),
# Other models follow this pattern
("ForCausalLM", "generate"),
("ForConditionalGeneration", "generate"),
("ChatModel", "generate"),
("LMHeadModel", "generate"),
("EmbeddingModel", "embedding"),
("RewardModel", "embedding"),
("ForSequenceClassification", "embedding"),
]
info, arch = ModelRegistry.inspect_model_cls(architectures)
for suffix, pref_task in suffix_to_preferred_task:
if arch.endswith(suffix) and pref_task in supported_tasks:
selected_task = pref_task
break
else:
if (arch.endswith("Model")
and info.architecture.endswith("ForCausalLM")
and "embedding" in supported_tasks):
selected_task = "embedding"
logger.info(
"This model supports multiple tasks: %s. "
"Defaulting to '%s'.", supported_tasks, selected_task)

View File

@ -11,8 +11,8 @@ from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
resolve_mm_processor_kwargs)
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
print_warning_once, resolve_mm_processor_kwargs)
from .data import ProcessorInputs, SingletonInputs
from .parse import is_encoder_decoder_inputs
@ -136,12 +136,12 @@ class InputRegistry:
"""
def __init__(self) -> None:
self._dummy_factories_by_model_type: Dict[Type[nn.Module],
DummyDataFactory] = {}
self._dummy_encoder_factories_by_model_type: Dict[
Type[nn.Module], DummyDataFactory] = {}
self._input_processors_by_model_type: Dict[Type[nn.Module],
InputProcessor] = {}
self._dummy_factories_by_model_type = \
ClassRegistry[nn.Module, DummyDataFactory]()
self._dummy_encoder_factories_by_model_type = \
ClassRegistry[nn.Module, DummyDataFactory]()
self._input_processors_by_model_type = \
ClassRegistry[nn.Module, InputProcessor]()
def _default_dummy_data_factory(
self,

View File

@ -60,9 +60,7 @@ class Pooler(nn.Module):
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None,
) -> Optional["Pooler"]:
if pooler_config is None:
return None
) -> "Pooler":
return cls(
pooling_type=PoolingType[pooler_config.pooling_type]
if pooler_config.pooling_type is not None else pooling_type,

View File

@ -9,6 +9,7 @@ import itertools
import json
import math
import os
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
@ -97,22 +98,31 @@ def device_loading_context(module: torch.nn.Module,
logger = init_logger(__name__)
def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
def _initialize_model(
vllm_config: VllmConfig,
*,
prefix: str = "",
architectures: Optional[list[str]] = None,
) -> nn.Module:
"""Initialize a model with the given configurations."""
model_config = vllm_config.model_config
model_class, _ = get_model_architecture(model_config)
model_class, _ = get_model_architecture(model_config,
architectures=architectures)
signatures = inspect.signature(model_class.__init__)
all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params:
# new-style model class
with set_current_vllm_config(vllm_config):
return model_class(vllm_config=vllm_config, prefix=prefix)
msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
"input arguments. Possibly you have an old-style model class"
" registered from out of tree and it is used for new vLLM version. "
"Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
"for the design and update the model class accordingly.")
logger.warning(msg)
warnings.warn(msg, DeprecationWarning, stacklevel=2)
logger.warning(
"Trying to guess the arguments for old-style model class %s",
model_class,
@ -356,7 +366,7 @@ class DefaultModelLoader(BaseModelLoader):
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(
self._get_all_weights(model_config, model))
# We only enable strict check for non-quantiized models
# We only enable strict check for non-quantized models
# that have loaded weights tracking currently.
if model_config.quantization is None and loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights

View File

@ -1,12 +1,13 @@
"""Utilities for selecting and loading models."""
import contextlib
from typing import Tuple, Type
from typing import Optional, Tuple, Type
import torch
from torch import nn
from vllm.config import ModelConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import as_embedding_model
@contextlib.contextmanager
@ -19,8 +20,13 @@ def set_default_torch_dtype(dtype: torch.dtype):
def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
model_config: ModelConfig,
*,
architectures: Optional[list[str]] = None,
) -> Tuple[Type[nn.Module], str]:
if architectures is None:
architectures = getattr(model_config.hf_config, "architectures", [])
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported = [
@ -32,7 +38,11 @@ def get_model_architecture(
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]
return ModelRegistry.resolve_model_cls(architectures)
model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
if model_config.task == "embedding":
model_cls = as_embedding_model(model_cls)
return model_cls, arch
def get_architecture_class_name(model_config: ModelConfig) -> str:

View File

@ -0,0 +1,98 @@
from collections.abc import Iterable
from typing import Any, TypeVar
import torch
import torch.nn as nn
from .interfaces_base import VllmModelForEmbedding, is_embedding_model
_T = TypeVar("_T", bound=type[nn.Module])
def as_embedding_model(cls: _T) -> _T:
"""Subclass an existing vLLM model to support embeddings."""
# Avoid modifying existing embedding models
if is_embedding_model(cls):
return cls
# Lazy import
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import (Pooler, PoolerOutput,
PoolingType)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from .utils import AutoWeightsLoader, WeightsMapper
class ModelForEmbedding(cls, VllmModelForEmbedding):
def __init__(
self,
*,
vllm_config: "VllmConfig",
prefix: str = "",
**kwargs: Any,
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
# These are not used in embedding models
for attr in ("lm_head", "logits_processor"):
if hasattr(self, attr):
delattr(self, attr)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
# If the model already defines a pooler instance, don't overwrite it
if not getattr(self, "_pooler", None):
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False,
)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
# TODO: Support uninitialized params tracking
# We have deleted this attribute, so don't load it
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
# If `*ForCausalLM` defines `load_weights` on the inner model
# and there are no other inner modules with parameters,
# we support loading from both `*Model` and `*ForCausalLM`
if hasattr(self, "model") and hasattr(self.model, "load_weights"):
# Whether only `self.model` contains parameters
model_is_only_param = all(
name == "model" or next(child.parameters(), None) is None
for name, child in self.named_children())
if model_is_only_param:
mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = mapper.apply(weights)
self.model.load_weights(weights)
return
# For most other models
if hasattr(cls, "load_weights"):
cls.load_weights(self, weights) # type: ignore
# Fallback
else:
loader = AutoWeightsLoader(self)
loader.load_weights(weights)
ModelForEmbedding.__name__ = cls.__name__ \
.removesuffix("ForCausalLM") \
.removesuffix("ForConditionalGeneration") \
.removesuffix("ChatModel") \
.removesuffix("LMHeadModel") + "ForEmbedding"
return ModelForEmbedding # type: ignore

View File

@ -512,9 +512,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)

View File

@ -30,19 +30,17 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@ -455,55 +453,3 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
class Gemma2EmbeddingModel(nn.Module, SupportsPP):
"""
A model that uses Gemma2 with additional embedding functionalities.
This class encapsulates the Gemma2Model and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of Gemma2Model used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.model = Gemma2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults(
vllm_config.model_config.pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
return self.model(input_ids, positions, kv_caches, attn_metadata,
intermediate_tensors, inputs_embeds)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
self.model.load_weights(weights)

View File

@ -474,9 +474,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.mlp1 = self._init_mlp1(config)

View File

@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
@ -47,14 +46,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
extract_layer_index, is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@ -511,11 +509,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
self.config = config
self.lora_config = lora_config
self.model = self._init_model(vllm_config=vllm_config, prefix=prefix)
self.model = self._init_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
@ -544,13 +543,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.sampler = get_sampler()
else:
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.STEP,
normalize=False,
softmax=False)
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
return LlamaModel(vllm_config=vllm_config, prefix=prefix)
@ -581,14 +576,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
logits = self.compute_logits(hidden_states, None)
return self._pooler(logits, pooling_metadata)
def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
@ -639,78 +626,3 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
name = name.replace(item, mapping[item])
return name, loaded_weight
class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
"""
A model that uses Llama with additional embedding functionalities.
This class encapsulates the LlamaModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of LlamaModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
]
embedding_modules = {
"embed_tokens": "input_embeddings",
}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
pooler_config = vllm_config.model_config.pooler_config
self.model = LlamaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
return self.model(input_ids, positions, kv_caches, attn_metadata,
intermediate_tensors, inputs_embeds)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
self.model.load_weights(weights)
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
self.model.load_kv_cache_scales(quantization_param_path)
# LRUCacheWorkerLoRAManager instantiation requires model config.
@property
def config(self):
return self.model.config

View File

@ -319,9 +319,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)

View File

@ -14,13 +14,11 @@ from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip,
@ -286,7 +284,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
multimodal_config = vllm_config.model_config.multimodal_config
vision_feature_layer = config.vision_feature_layer
@ -321,17 +318,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
# The same model class supports both language generation and embedding
# because the architecture name is the same
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@ -678,13 +669,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)

View File

@ -275,9 +275,10 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors)

View File

@ -422,9 +422,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
prefix=maybe_prefix(prefix, "vision_tower"))
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))

View File

@ -151,9 +151,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
self.quant_config = quant_config
config.text_config.architectures = ["GemmaForCausalLM"]
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.language_model.logits_processor.scale *= logit_scale

View File

@ -29,24 +29,22 @@ from vllm.config import ModelConfig, VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
logger = init_logger(__name__)
@ -536,7 +534,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
@ -556,18 +553,17 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
quant_config,
prefix=maybe_prefix(prefix, "model.vision_embed_tokens"))
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
# The prefix is empty intentionally because default prefix of
# LlamaForCausalLM is "model"
self.language_model = LlamaForCausalLM(vllm_config=vllm_config,
prefix="")
prefix="",
# We don't directly initialize vLLM's LlamaForCausalLM so we
# can automatically apply embedding wrapper if this model is
# initialized as an embedding model
architectures=["LlamaForCausalLM"],
)
# The same model class supports both language generation and embedding
# because the architecture name is the same
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@ -739,13 +735,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
hf_to_vllm_mapper = WeightsMapper(

View File

@ -172,9 +172,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# init MistralForCausalLM
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.vision_encoder = VisionTransformer(self.vision_args)
self.vision_language_adapter = VisionLanguageAdapter(

View File

@ -31,6 +31,7 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@ -55,6 +56,8 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
class Qwen2MLP(nn.Module):
@ -433,7 +436,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
self.config = config
self.lora_config = lora_config
@ -454,14 +456,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
# The same model class supports both language generation and embedding
# because the architecture name is the same
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
@ -499,13 +493,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
@ -553,6 +540,15 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
# TODO: Replace this model class with for_embedding(Qwen2ForCausalLM),
# after changing the default pooling method
if pooler_config.pooling_type is None:
logger.warning(
"This embedding model will default to last-token pooling in "
"an upcoming version. To avoid breaking changes, you should "
"pass `--override-pooler-config '{\"pooling_type\": \"MEAN\"}'`"
" explicitly.")
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.MEAN,

View File

@ -50,7 +50,6 @@ from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
@ -59,14 +58,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import cached_get_processor
@ -1070,7 +1068,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
multimodal_config = vllm_config.model_config.multimodal_config
assert not cache_config.enable_prefix_caching, \
"Qwen2-VL currently does not support prefix caching"
@ -1102,11 +1099,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
@ -1361,13 +1354,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [

View File

@ -20,6 +20,7 @@ import torch.nn as nn
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .adapters import as_embedding_model
from .interfaces import (has_inner_state, is_attention_free,
supports_cross_encoding, supports_multimodal,
supports_pp)
@ -107,15 +108,15 @@ _EMBEDDING_MODELS = {
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"LlamaModel": ("llama", "LlamaEmbeddingModel"),
"LlamaModel": ("llama", "LlamaForCausalLM"),
**{
# Multiple models share the same architecture, so we include them all
k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
if arch == "LlamaForCausalLM"
},
"MistralModel": ("llama", "LlamaEmbeddingModel"),
"MistralModel": ("llama", "LlamaForCausalLM"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
@ -125,7 +126,7 @@ _EMBEDDING_MODELS = {
# [Multimodal]
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501,
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
}
_CROSS_ENCODER_MODELS = {
@ -208,6 +209,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
@dataclass(frozen=True)
class _ModelInfo:
architecture: str
is_text_generation_model: bool
is_embedding_model: bool
supports_cross_encoding: bool
@ -218,9 +220,19 @@ class _ModelInfo:
@staticmethod
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
is_embedding_model_ = is_embedding_model(model)
if not is_embedding_model_:
try:
as_embedding_model(model)
except Exception:
pass
else:
is_embedding_model_ = True
return _ModelInfo(
architecture=model.__name__,
is_text_generation_model=is_text_generation_model(model),
is_embedding_model=is_embedding_model(model),
is_embedding_model=is_embedding_model_,
supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model),
supports_pp=supports_pp(model),
@ -399,13 +411,13 @@ class _ModelRegistry:
def inspect_model_cls(
self,
architectures: Union[str, List[str]],
) -> _ModelInfo:
) -> Tuple[_ModelInfo, str]:
architectures = self._normalize_archs(architectures)
for arch in architectures:
model_info = self._try_inspect_model_cls(arch)
if model_info is not None:
return model_info
return (model_info, arch)
return self._raise_for_unsupported(architectures)
@ -426,39 +438,50 @@ class _ModelRegistry:
self,
architectures: Union[str, List[str]],
) -> bool:
return self.inspect_model_cls(architectures).is_text_generation_model
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_text_generation_model
def is_embedding_model(
self,
architectures: Union[str, List[str]],
) -> bool:
return self.inspect_model_cls(architectures).is_embedding_model
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_embedding_model
def is_cross_encoder_model(
self,
architectures: Union[str, List[str]],
) -> bool:
return self.inspect_model_cls(architectures).supports_cross_encoding
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_cross_encoding
def is_multimodal_model(
self,
architectures: Union[str, List[str]],
) -> bool:
return self.inspect_model_cls(architectures).supports_multimodal
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_multimodal
def is_pp_supported_model(
self,
architectures: Union[str, List[str]],
) -> bool:
return self.inspect_model_cls(architectures).supports_pp
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_pp
def model_has_inner_state(self, architectures: Union[str,
List[str]]) -> bool:
return self.inspect_model_cls(architectures).has_inner_state
def model_has_inner_state(
self,
architectures: Union[str, List[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.has_inner_state
def is_attention_free_model(self, architectures: Union[str,
List[str]]) -> bool:
return self.inspect_model_cls(architectures).is_attention_free
def is_attention_free_model(
self,
architectures: Union[str, List[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_attention_free
ModelRegistry = _ModelRegistry({

View File

@ -360,9 +360,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
))
self.multi_modal_projector = UltravoxProjector(config)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
if config.text_model_id is not None:
# this prefix is not for initialization, but for loading weights
# note the trailing dot

View File

@ -173,8 +173,15 @@ class AutoWeightsLoader:
module_load_weights = getattr(module, "load_weights", None)
if callable(module_load_weights):
loaded_params = module_load_weights(weights)
yield from map(lambda x: self._get_qualname(base_prefix, x),
loaded_params)
if loaded_params is None:
logger.warning(
"Unable to collect loaded parameters "
"for module %s", module)
else:
yield from map(
lambda x: self._get_qualname(base_prefix, x),
loaded_params,
)
child_modules = dict(module.named_children())
child_params = dict(module.named_parameters(recurse=False))
@ -232,17 +239,24 @@ class AutoWeightsLoader:
def init_vllm_registered_model(
hf_config: PretrainedConfig,
vllm_config: VllmConfig,
*,
prefix: str = "",
hf_config: Optional[PretrainedConfig] = None,
architectures: Optional[list[str]] = None,
) -> nn.Module:
"""
Helper function to initialize an inner model registered to vLLM,
based on the arguments passed to the outer vLLM model.
"""
from vllm.model_executor.model_loader.loader import _initialize_model
if hf_config is not None:
vllm_config = vllm_config.with_hf_config(hf_config)
return _initialize_model(vllm_config, prefix)
return _initialize_model(vllm_config=vllm_config,
prefix=prefix,
architectures=architectures)
@overload

View File

@ -7,7 +7,7 @@ from torch import nn
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.utils import (get_allowed_kwarg_only_overrides,
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
resolve_mm_processor_kwargs)
if TYPE_CHECKING:
@ -54,8 +54,8 @@ class MultiModalPlugin(ABC):
"""
def __init__(self) -> None:
self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}
self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]()
self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]()
@abstractmethod
def get_data_key(self) -> str:

View File

@ -9,6 +9,7 @@ from typing_extensions import TypeAlias
from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import ClassRegistry
from .audio import AudioPlugin
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
@ -62,8 +63,8 @@ class MultiModalRegistry:
plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
self._plugins = {p.get_data_key(): p for p in plugins}
self._processor_factories: Dict[Type[nn.Module],
MultiModalProcessorFactory] = {}
self._processor_factories = ClassRegistry[nn.Module,
MultiModalProcessorFactory]()
# This is used for non-multimodal models
self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}

View File

@ -20,7 +20,7 @@ import uuid
import warnings
import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections import defaultdict
from collections import UserDict, defaultdict
from collections.abc import Iterable, Mapping
from functools import lru_cache, partial, wraps
from platform import uname
@ -1517,13 +1517,13 @@ class AtomicCounter:
# Adapted from: https://stackoverflow.com/a/47212782/5082708
class LazyDict(Mapping, Generic[T]):
class LazyDict(Mapping[str, T], Generic[T]):
def __init__(self, factory: Dict[str, Callable[[], T]]):
self._factory = factory
self._dict: Dict[str, T] = {}
def __getitem__(self, key) -> T:
def __getitem__(self, key: str) -> T:
if key not in self._dict:
if key not in self._factory:
raise KeyError(key)
@ -1540,6 +1540,22 @@ class LazyDict(Mapping, Generic[T]):
return len(self._factory)
class ClassRegistry(UserDict[type[T], _V]):
def __getitem__(self, key: type[T]) -> _V:
for cls in key.mro():
if cls in self.data:
return self.data[cls]
raise KeyError(key)
def __contains__(self, key: object) -> bool:
if not isinstance(key, type):
return False
return any(cls in self.data for cls in key.mro())
def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""
Create a weak reference to a tensor.