[Model] Explicit interface for vLLM models and support OOT embedding models (#9108)
Parent: 18b296fdb2
Commit: 8c6de96ea1
@@ -871,6 +871,7 @@ def num_gpus_available():
 temp_dir = tempfile.gettempdir()
 _dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
 _dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
+_dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding")
 
 
 @pytest.fixture
@@ -909,3 +910,22 @@ def dummy_llava_path():
         with open(json_path, "w") as f:
             json.dump(config, f)
     return _dummy_llava_path
+
+
+@pytest.fixture
+def dummy_gemma2_embedding_path():
+    json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json")
+    if not os.path.exists(_dummy_gemma2_embedding_path):
+        snapshot_download(repo_id="BAAI/bge-multilingual-gemma2",
+                          local_dir=_dummy_gemma2_embedding_path,
+                          ignore_patterns=[
+                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
+                              "*.msgpack"
+                          ])
+        assert os.path.exists(json_path)
+        with open(json_path, "r") as f:
+            config = json.load(f)
+        config["architectures"] = ["MyGemma2Embedding"]
+        with open(json_path, "w") as f:
+            json.dump(config, f)
+    return _dummy_gemma2_embedding_path
@@ -2,7 +2,7 @@ import os
 
 import pytest
 
-from vllm import LLM, SamplingParams
+from vllm import LLM, PoolingParams, SamplingParams
 from vllm.assets.image import ImageAsset
 
 from ..utils import fork_new_process_for_each_test
@@ -17,7 +17,7 @@ def test_plugin(dummy_opt_path):
 
 
 @fork_new_process_for_each_test
-def test_oot_registration(dummy_opt_path):
+def test_oot_registration_text_generation(dummy_opt_path):
     os.environ["VLLM_PLUGINS"] = "register_dummy_model"
     prompts = ["Hello, my name is", "The text does not matter"]
     sampling_params = SamplingParams(temperature=0)
@@ -32,11 +32,23 @@ def test_oot_registration(dummy_opt_path):
         assert rest == ""
 
 
+@fork_new_process_for_each_test
+def test_oot_registration_embedding(dummy_gemma2_embedding_path):
+    os.environ["VLLM_PLUGINS"] = "register_dummy_model"
+    prompts = ["Hello, my name is", "The text does not matter"]
+    sampling_params = PoolingParams()
+    llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy")
+    outputs = llm.encode(prompts, sampling_params)
+
+    for output in outputs:
+        assert all(v == 0 for v in output.outputs.embedding)
+
+
 image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
 
 
 @fork_new_process_for_each_test
-def test_oot_multimodal_registration(dummy_llava_path):
+def test_oot_registration_multimodal(dummy_llava_path):
     os.environ["VLLM_PLUGINS"] = "register_dummy_model"
     prompts = [{
         "prompt": "What's in the image?<image>",
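For readers skimming the diff, here is a minimal sketch (not part of the change; the checkpoint path is hypothetical, and the plugin package name is taken from the tests above) of how an out-of-tree embedding model is registered and then queried once this lands:

    from vllm import LLM, PoolingParams
    from vllm.model_executor.models import ModelRegistry

    # Lazy registration by "module:class" string, mirroring the dummy plugin.
    ModelRegistry.register_model(
        "MyGemma2Embedding",
        "vllm_add_dummy_model.my_gemma_embedding:MyGemma2Embedding",
    )

    llm = LLM(model="/path/to/gemma2-embedding-checkpoint")  # hypothetical path
    outputs = llm.encode(["Hello, my name is"], PoolingParams())
    embedding = outputs[0].outputs.embedding  # list of floats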
@@ -3,7 +3,14 @@ import warnings
 import pytest
 import torch.cuda
 
-from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.models import (is_embedding_model,
+                                        is_text_generation_model,
+                                        supports_multimodal)
+from vllm.model_executor.models.registry import (_EMBEDDING_MODELS,
+                                                  _MULTIMODAL_MODELS,
+                                                  _SPECULATIVE_DECODING_MODELS,
+                                                  _TEXT_GENERATION_MODELS,
+                                                  ModelRegistry)
 from vllm.platforms import current_platform
 
 from ..utils import fork_new_process_for_each_test
@@ -12,7 +19,20 @@ from ..utils import fork_new_process_for_each_test
 @pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
 def test_registry_imports(model_arch):
     # Ensure all model classes can be imported successfully
-    ModelRegistry.resolve_model_cls(model_arch)
+    model_cls, _ = ModelRegistry.resolve_model_cls(model_arch)
+
+    if model_arch in _SPECULATIVE_DECODING_MODELS:
+        pass  # Ignore these models which do not have a unified format
+    else:
+        assert is_text_generation_model(model_cls) is (
+            model_arch in _TEXT_GENERATION_MODELS
+            or model_arch in _MULTIMODAL_MODELS)
+
+        assert is_embedding_model(model_cls) is (model_arch
+                                                 in _EMBEDDING_MODELS)
+
+        assert supports_multimodal(model_cls) is (model_arch
+                                                  in _MULTIMODAL_MODELS)
 
 
 @fork_new_process_for_each_test
@@ -9,6 +9,12 @@ def register():
     ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)
 
     # Test passing lazy model
+    if "MyGemma2Embedding" not in ModelRegistry.get_supported_archs():
+        ModelRegistry.register_model(
+            "MyGemma2Embedding",
+            "vllm_add_dummy_model.my_gemma_embedding:MyGemma2Embedding",
+        )
+
     if "MyLlava" not in ModelRegistry.get_supported_archs():
         ModelRegistry.register_model("MyLlava",
                                      "vllm_add_dummy_model.my_llava:MyLlava")
@@ -0,0 +1,34 @@
+from typing import List, Optional, Union
+
+import torch
+
+from vllm.attention import AttentionMetadata
+from vllm.model_executor.models.gemma2_embedding import Gemma2EmbeddingModel
+from vllm.sequence import IntermediateTensors
+
+
+class MyGemma2Embedding(Gemma2EmbeddingModel):
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = super().forward(
+            input_ids,
+            positions,
+            kv_caches,
+            attn_metadata,
+            intermediate_tensors=intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+        )
+
+        if isinstance(hidden_states, IntermediateTensors):
+            return hidden_states
+
+        # Return all-zero embeddings
+        return torch.zeros_like(hidden_states)
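Note on the dummy model above: it overrides forward() to return all-zero hidden states, which is what lets the embedding test earlier in this diff assert all(v == 0 for v in output.outputs.embedding) and so confirm that the out-of-tree class, not the in-tree Gemma2EmbeddingModel, was actually loaded.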
@@ -1,10 +1,17 @@
 from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
                          SupportsPP, has_inner_state, supports_lora,
                          supports_multimodal, supports_pp)
+from .interfaces_base import (VllmModelForEmbedding,
+                              VllmModelForTextGeneration, is_embedding_model,
+                              is_text_generation_model)
 from .registry import ModelRegistry
 
 __all__ = [
     "ModelRegistry",
+    "VllmModelForEmbedding",
+    "is_embedding_model",
+    "VllmModelForTextGeneration",
+    "is_text_generation_model",
     "HasInnerState",
     "has_inner_state",
     "SupportsLoRA",
@@ -1,4 +1,3 @@
-import inspect
 from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
                     Protocol, Type, Union, overload, runtime_checkable)
 
@@ -6,9 +5,9 @@ import torch
 from typing_extensions import TypeIs
 
 from vllm.logger import init_logger
+from vllm.utils import supports_kw
 
 if TYPE_CHECKING:
-    from vllm.attention import AttentionMetadata
     from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
     from vllm.sequence import IntermediateTensors
 
@@ -142,9 +141,7 @@ def supports_lora(
     return result
 
 
-def _supports_lora(
-    model: Union[Type[object], object],
-) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
+def _supports_lora(model: Union[Type[object], object]) -> bool:
     if isinstance(model, type):
         return isinstance(model, _SupportsLoRAType)
 
@@ -175,10 +172,7 @@ class SupportsPP(Protocol):
 
     def forward(
         self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: "AttentionMetadata",
+        *,
         intermediate_tensors: Optional["IntermediateTensors"],
     ) -> Union[torch.Tensor, "IntermediateTensors"]:
         """
@@ -205,10 +199,7 @@ class _SupportsPPType(Protocol):
 
     def forward(
         self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: "AttentionMetadata",
+        *,
        intermediate_tensors: Optional["IntermediateTensors"],
     ) -> Union[torch.Tensor, "IntermediateTensors"]:
         ...
@@ -257,24 +248,19 @@ def supports_pp(
     return supports_attributes and supports_inspect
 
 
-def _supports_pp_attributes(
-    model: Union[Type[object], object],
-) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
+def _supports_pp_attributes(model: Union[Type[object], object]) -> bool:
     if isinstance(model, type):
         return isinstance(model, _SupportsPPType)
 
     return isinstance(model, SupportsPP)
 
 
-def _supports_pp_inspect(
-    model: Union[Type[object], object],
-) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
+def _supports_pp_inspect(model: Union[Type[object], object]) -> bool:
     model_forward = getattr(model, "forward", None)
     if not callable(model_forward):
         return False
 
-    forward_params = inspect.signature(model_forward).parameters
-    return "intermediate_tensors" in forward_params
+    return supports_kw(model_forward, "intermediate_tensors")
 
 
 @runtime_checkable
vllm/model_executor/models/interfaces_base.py (new file, 191 lines)
@@ -0,0 +1,191 @@
+from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union,
+                    overload, runtime_checkable)
+
+import torch
+import torch.nn as nn
+from transformers import PretrainedConfig
+from typing_extensions import TypeIs, TypeVar
+
+from vllm.logger import init_logger
+from vllm.utils import supports_kw
+
+if TYPE_CHECKING:
+    from vllm.attention import AttentionMetadata
+    from vllm.config import CacheConfig
+    from vllm.model_executor.layers.pooler import PoolerOutput
+    from vllm.model_executor.layers.quantization import QuantizationConfig
+    from vllm.model_executor.layers.sampler import SamplerOutput
+    from vllm.model_executor.pooling_metadata import PoolingMetadata
+    from vllm.model_executor.sampling_metadata import SamplingMetadata
+
+logger = init_logger(__name__)
+
+# The type of HF config
+C_co = TypeVar("C_co", bound=PretrainedConfig, covariant=True)
+
+# The type of hidden states
+# Currently, T = torch.Tensor for all models except for Medusa
+# which has T = List[torch.Tensor]
+T = TypeVar("T", default=torch.Tensor)
+T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)
+
+# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags
+# for the base interfaces to avoid breaking OOT registration for existing models
+# that don't inherit from the base interface classes
+
+
+@runtime_checkable
+class VllmModel(Protocol[C_co, T_co]):
+
+    def __init__(
+        self,
+        config: C_co,
+        *,
+        cache_config: Optional["CacheConfig"],
+        quant_config: Optional["QuantizationConfig"],
+    ) -> None:
+        ...
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: "AttentionMetadata",
+    ) -> T_co:
+        ...
+
+
+def _check_vllm_model_init(model: Union[Type[object], object]) -> bool:
+    model_init = model.__init__
+    vllm_kws = ("cache_config", "quant_config")
+    missing_kws = tuple(kw for kw in vllm_kws
+                        if not supports_kw(model_init, kw))
+
+    if missing_kws and (isinstance(model, type)
+                        and issubclass(model, nn.Module)):
+        logger.warning(
+            "The model (%s) is missing "
+            "vLLM-specific keywords from its initializer: %s",
+            model,
+            missing_kws,
+        )
+
+    return len(missing_kws) == 0
+
+
+def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
+    model_forward = getattr(model, "forward", None)
+    if not callable(model_forward):
+        return False
+
+    vllm_kws = ("input_ids", "positions", "kv_caches", "attn_metadata")
+    missing_kws = tuple(kw for kw in vllm_kws
+                        if not supports_kw(model_forward, kw))
+
+    if missing_kws and (isinstance(model, type)
+                        and issubclass(model, nn.Module)):
+        logger.warning(
+            "The model (%s) is missing "
+            "vLLM-specific keywords from its initializer: %s",
+            model,
+            missing_kws,
+        )
+
+    return len(missing_kws) == 0
+
+
+@overload
+def is_vllm_model(model: Type[object]) -> TypeIs[Type[VllmModel]]:
+    ...
+
+
+@overload
+def is_vllm_model(model: object) -> TypeIs[VllmModel]:
+    ...
+
+
+def is_vllm_model(
+    model: Union[Type[object], object],
+) -> Union[TypeIs[Type[VllmModel]], TypeIs[VllmModel]]:
+    return _check_vllm_model_init(model) and _check_vllm_model_forward(model)
+
+
+@runtime_checkable
+class VllmModelForTextGeneration(VllmModel[C_co, T], Protocol[C_co, T]):
+
+    def compute_logits(
+        self,
+        hidden_states: T,
+        sampling_metadata: "SamplingMetadata",
+    ) -> Optional[T]:
+        """Return `None` if TP rank > 0."""
+        ...
+
+    def sample(
+        self,
+        logits: T,
+        sampling_metadata: "SamplingMetadata",
+    ) -> "SamplerOutput":
+        """Only called on TP rank 0."""
+        ...
+
+
+@overload
+def is_text_generation_model(
+        model: Type[object]) -> TypeIs[Type[VllmModelForTextGeneration]]:
+    ...
+
+
+@overload
+def is_text_generation_model(
+        model: object) -> TypeIs[VllmModelForTextGeneration]:
+    ...
+
+
+def is_text_generation_model(
+    model: Union[Type[object], object],
+) -> Union[TypeIs[Type[VllmModelForTextGeneration]],
+           TypeIs[VllmModelForTextGeneration]]:
+    if not is_vllm_model(model):
+        return False
+
+    if isinstance(model, type):
+        return isinstance(model, VllmModelForTextGeneration)
+
+    return isinstance(model, VllmModelForTextGeneration)
+
+
+@runtime_checkable
+class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]):
+
+    def pooler(
+        self,
+        hidden_states: T,
+        pooling_metadata: "PoolingMetadata",
+    ) -> "PoolerOutput":
+        """Only called on TP rank 0."""
+        ...
+
+
+@overload
+def is_embedding_model(
+        model: Type[object]) -> TypeIs[Type[VllmModelForEmbedding]]:
+    ...
+
+
+@overload
+def is_embedding_model(model: object) -> TypeIs[VllmModelForEmbedding]:
+    ...
+
+
+def is_embedding_model(
+    model: Union[Type[object], object],
+) -> Union[TypeIs[Type[VllmModelForEmbedding]], TypeIs[VllmModelForEmbedding]]:
+    if not is_vllm_model(model):
+        return False
+
+    if isinstance(model, type):
+        return isinstance(model, VllmModelForEmbedding)
+
+    return isinstance(model, VllmModelForEmbedding)
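A brief illustration (not part of the diff; it assumes the vllm_add_dummy_model test package added above is installed) of how the new runtime-checkable protocols classify a model class by structure rather than by inheritance:

    from vllm.model_executor.models import (is_embedding_model,
                                            is_text_generation_model)
    from vllm_add_dummy_model.my_gemma_embedding import MyGemma2Embedding

    # MyGemma2Embedding exposes pooler() but no compute_logits()/sample(),
    # so it satisfies VllmModelForEmbedding only.
    assert is_embedding_model(MyGemma2Embedding)
    assert not is_text_generation_model(MyGemma2Embedding)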
@@ -12,10 +12,12 @@ from vllm.logger import init_logger
 from vllm.utils import is_hip
 
 from .interfaces import supports_multimodal, supports_pp
+from .interfaces_base import is_embedding_model, is_text_generation_model
 
 logger = init_logger(__name__)
 
-_GENERATION_MODELS = {
+_TEXT_GENERATION_MODELS = {
+    # [Decoder-only]
     "AquilaModel": ("llama", "LlamaForCausalLM"),
     "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
     "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
@@ -74,10 +76,9 @@ _GENERATION_MODELS = {
     "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
     "SolarForCausalLM": ("solar", "SolarForCausalLM"),
     "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
-    # NOTE: The below models are for speculative decoding only
-    "MedusaModel": ("medusa", "Medusa"),
-    "EAGLEModel": ("eagle", "EAGLE"),
-    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
+    # [Encoder-decoder]
+    "BartModel": ("bart", "BartForConditionalGeneration"),
+    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
 }
 
 _EMBEDDING_MODELS = {
@@ -114,16 +115,18 @@ _MULTIMODAL_MODELS = {
     "MllamaForConditionalGeneration": ("mllama",
                                        "MllamaForConditionalGeneration"),
 }
-_CONDITIONAL_GENERATION_MODELS = {
-    "BartModel": ("bart", "BartForConditionalGeneration"),
-    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
+
+_SPECULATIVE_DECODING_MODELS = {
+    "EAGLEModel": ("eagle", "EAGLE"),
+    "MedusaModel": ("medusa", "Medusa"),
+    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
 }
 
 _MODELS = {
-    **_GENERATION_MODELS,
+    **_TEXT_GENERATION_MODELS,
     **_EMBEDDING_MODELS,
     **_MULTIMODAL_MODELS,
-    **_CONDITIONAL_GENERATION_MODELS,
+    **_SPECULATIVE_DECODING_MODELS,
 }
 
 # Architecture -> type or (module, class).
@@ -317,6 +320,19 @@ class ModelRegistry:
 
         return result.returncode == 0
 
+    @staticmethod
+    def is_text_generation_model(architectures: Union[str, List[str]]) -> bool:
+        if isinstance(architectures, str):
+            architectures = [architectures]
+        if not architectures:
+            logger.warning("No model architectures are specified")
+
+        is_txt_gen = partial(ModelRegistry._check_stateless,
+                             is_text_generation_model,
+                             default=False)
+
+        return any(is_txt_gen(arch) for arch in architectures)
+
     @staticmethod
     def is_embedding_model(architectures: Union[str, List[str]]) -> bool:
         if isinstance(architectures, str):
@@ -324,7 +340,11 @@ class ModelRegistry:
         if not architectures:
             logger.warning("No model architectures are specified")
 
-        return any(arch in _EMBEDDING_MODELS for arch in architectures)
+        is_emb = partial(ModelRegistry._check_stateless,
+                         is_embedding_model,
+                         default=False)
+
+        return any(is_emb(arch) for arch in architectures)
 
     @staticmethod
     def is_multimodal_model(architectures: Union[str, List[str]]) -> bool:
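In contrast to the class-level helpers in interfaces_base.py, the ModelRegistry methods take architecture names as they appear in a Hugging Face config. An illustrative sketch (not part of the diff), using an in-tree architecture:

    from vllm.model_executor.models import ModelRegistry

    # LlamaForCausalLM defines compute_logits()/sample() but no pooler().
    assert ModelRegistry.is_text_generation_model("LlamaForCausalLM")
    assert not ModelRegistry.is_embedding_model("LlamaForCausalLM")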
@@ -1277,6 +1277,15 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
     return await task(*args, **kwargs)
 
 
+def supports_kw(callable: Callable[..., object], kw_name: str) -> bool:
+    params = inspect.signature(callable).parameters
+    if kw_name in params:
+        return True
+
+    return any(param.kind == inspect.Parameter.VAR_KEYWORD
+               for param in params.values())
+
+
 def get_allowed_kwarg_only_overrides(
     callable: Callable[..., object],
     overrides: Optional[Dict[str, Any]],
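A small usage note (illustrative, not part of the diff): supports_kw treats either an explicit parameter or a **kwargs catch-all as accepting the keyword, which is what the pipeline-parallel and base-interface checks earlier in this diff rely on:

    def forward(input_ids, positions, **kwargs):
        ...

    assert supports_kw(forward, "positions")              # explicit parameter
    assert supports_kw(forward, "intermediate_tensors")   # absorbed by **kwargs
    assert not supports_kw(lambda x: x, "positions")      # neither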