[Model] Explicit interface for vLLM models and support OOT embedding models (#9108)

This commit is contained in:
Cyrus Leung 2024-10-07 14:10:35 +08:00 committed by GitHub
parent 18b296fdb2
commit 8c6de96ea1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 342 additions and 37 deletions

View File

@ -871,6 +871,7 @@ def num_gpus_available():
temp_dir = tempfile.gettempdir() temp_dir = tempfile.gettempdir()
_dummy_opt_path = os.path.join(temp_dir, "dummy_opt") _dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
_dummy_llava_path = os.path.join(temp_dir, "dummy_llava") _dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
_dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding")
@pytest.fixture @pytest.fixture
@ -909,3 +910,22 @@ def dummy_llava_path():
with open(json_path, "w") as f: with open(json_path, "w") as f:
json.dump(config, f) json.dump(config, f)
return _dummy_llava_path return _dummy_llava_path
@pytest.fixture
def dummy_gemma2_embedding_path():
json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json")
if not os.path.exists(_dummy_gemma2_embedding_path):
snapshot_download(repo_id="BAAI/bge-multilingual-gemma2",
local_dir=_dummy_gemma2_embedding_path,
ignore_patterns=[
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
"*.msgpack"
])
assert os.path.exists(json_path)
with open(json_path, "r") as f:
config = json.load(f)
config["architectures"] = ["MyGemma2Embedding"]
with open(json_path, "w") as f:
json.dump(config, f)
return _dummy_gemma2_embedding_path

View File

@ -2,7 +2,7 @@ import os
import pytest import pytest
from vllm import LLM, SamplingParams from vllm import LLM, PoolingParams, SamplingParams
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from ..utils import fork_new_process_for_each_test from ..utils import fork_new_process_for_each_test
@ -17,7 +17,7 @@ def test_plugin(dummy_opt_path):
@fork_new_process_for_each_test @fork_new_process_for_each_test
def test_oot_registration(dummy_opt_path): def test_oot_registration_text_generation(dummy_opt_path):
os.environ["VLLM_PLUGINS"] = "register_dummy_model" os.environ["VLLM_PLUGINS"] = "register_dummy_model"
prompts = ["Hello, my name is", "The text does not matter"] prompts = ["Hello, my name is", "The text does not matter"]
sampling_params = SamplingParams(temperature=0) sampling_params = SamplingParams(temperature=0)
@ -32,11 +32,23 @@ def test_oot_registration(dummy_opt_path):
assert rest == "" assert rest == ""
@fork_new_process_for_each_test
def test_oot_registration_embedding(dummy_gemma2_embedding_path):
os.environ["VLLM_PLUGINS"] = "register_dummy_model"
prompts = ["Hello, my name is", "The text does not matter"]
sampling_params = PoolingParams()
llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy")
outputs = llm.encode(prompts, sampling_params)
for output in outputs:
assert all(v == 0 for v in output.outputs.embedding)
image = ImageAsset("cherry_blossom").pil_image.convert("RGB") image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
@fork_new_process_for_each_test @fork_new_process_for_each_test
def test_oot_multimodal_registration(dummy_llava_path): def test_oot_registration_multimodal(dummy_llava_path):
os.environ["VLLM_PLUGINS"] = "register_dummy_model" os.environ["VLLM_PLUGINS"] = "register_dummy_model"
prompts = [{ prompts = [{
"prompt": "What's in the image?<image>", "prompt": "What's in the image?<image>",

View File

@ -3,7 +3,14 @@ import warnings
import pytest import pytest
import torch.cuda import torch.cuda
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import (is_embedding_model,
is_text_generation_model,
supports_multimodal)
from vllm.model_executor.models.registry import (_EMBEDDING_MODELS,
_MULTIMODAL_MODELS,
_SPECULATIVE_DECODING_MODELS,
_TEXT_GENERATION_MODELS,
ModelRegistry)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..utils import fork_new_process_for_each_test from ..utils import fork_new_process_for_each_test
@ -12,7 +19,20 @@ from ..utils import fork_new_process_for_each_test
@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs()) @pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
def test_registry_imports(model_arch): def test_registry_imports(model_arch):
# Ensure all model classes can be imported successfully # Ensure all model classes can be imported successfully
ModelRegistry.resolve_model_cls(model_arch) model_cls, _ = ModelRegistry.resolve_model_cls(model_arch)
if model_arch in _SPECULATIVE_DECODING_MODELS:
pass # Ignore these models which do not have a unified format
else:
assert is_text_generation_model(model_cls) is (
model_arch in _TEXT_GENERATION_MODELS
or model_arch in _MULTIMODAL_MODELS)
assert is_embedding_model(model_cls) is (model_arch
in _EMBEDDING_MODELS)
assert supports_multimodal(model_cls) is (model_arch
in _MULTIMODAL_MODELS)
@fork_new_process_for_each_test @fork_new_process_for_each_test

View File

@ -9,6 +9,12 @@ def register():
ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM) ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)
# Test passing lazy model # Test passing lazy model
if "MyGemma2Embedding" not in ModelRegistry.get_supported_archs():
ModelRegistry.register_model(
"MyGemma2Embedding",
"vllm_add_dummy_model.my_gemma_embedding:MyGemma2Embedding",
)
if "MyLlava" not in ModelRegistry.get_supported_archs(): if "MyLlava" not in ModelRegistry.get_supported_archs():
ModelRegistry.register_model("MyLlava", ModelRegistry.register_model("MyLlava",
"vllm_add_dummy_model.my_llava:MyLlava") "vllm_add_dummy_model.my_llava:MyLlava")

View File

@ -0,0 +1,34 @@
from typing import List, Optional, Union
import torch
from vllm.attention import AttentionMetadata
from vllm.model_executor.models.gemma2_embedding import Gemma2EmbeddingModel
from vllm.sequence import IntermediateTensors
class MyGemma2Embedding(Gemma2EmbeddingModel):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = super().forward(
input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
if isinstance(hidden_states, IntermediateTensors):
return hidden_states
# Return all-zero embeddings
return torch.zeros_like(hidden_states)

View File

@ -1,10 +1,17 @@
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal, from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
SupportsPP, has_inner_state, supports_lora, SupportsPP, has_inner_state, supports_lora,
supports_multimodal, supports_pp) supports_multimodal, supports_pp)
from .interfaces_base import (VllmModelForEmbedding,
VllmModelForTextGeneration, is_embedding_model,
is_text_generation_model)
from .registry import ModelRegistry from .registry import ModelRegistry
__all__ = [ __all__ = [
"ModelRegistry", "ModelRegistry",
"VllmModelForEmbedding",
"is_embedding_model",
"VllmModelForTextGeneration",
"is_text_generation_model",
"HasInnerState", "HasInnerState",
"has_inner_state", "has_inner_state",
"SupportsLoRA", "SupportsLoRA",

View File

@ -1,4 +1,3 @@
import inspect
from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
Protocol, Type, Union, overload, runtime_checkable) Protocol, Type, Union, overload, runtime_checkable)
@ -6,9 +5,9 @@ import torch
from typing_extensions import TypeIs from typing_extensions import TypeIs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import supports_kw
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
@ -142,9 +141,7 @@ def supports_lora(
return result return result
def _supports_lora( def _supports_lora(model: Union[Type[object], object]) -> bool:
model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, _SupportsLoRAType) return isinstance(model, _SupportsLoRAType)
@ -175,10 +172,7 @@ class SupportsPP(Protocol):
def forward( def forward(
self, self,
input_ids: torch.Tensor, *,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: "AttentionMetadata",
intermediate_tensors: Optional["IntermediateTensors"], intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]: ) -> Union[torch.Tensor, "IntermediateTensors"]:
""" """
@ -205,10 +199,7 @@ class _SupportsPPType(Protocol):
def forward( def forward(
self, self,
input_ids: torch.Tensor, *,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: "AttentionMetadata",
intermediate_tensors: Optional["IntermediateTensors"], intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]: ) -> Union[torch.Tensor, "IntermediateTensors"]:
... ...
@ -257,24 +248,19 @@ def supports_pp(
return supports_attributes and supports_inspect return supports_attributes and supports_inspect
def _supports_pp_attributes( def _supports_pp_attributes(model: Union[Type[object], object]) -> bool:
model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, _SupportsPPType) return isinstance(model, _SupportsPPType)
return isinstance(model, SupportsPP) return isinstance(model, SupportsPP)
def _supports_pp_inspect( def _supports_pp_inspect(model: Union[Type[object], object]) -> bool:
model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
model_forward = getattr(model, "forward", None) model_forward = getattr(model, "forward", None)
if not callable(model_forward): if not callable(model_forward):
return False return False
forward_params = inspect.signature(model_forward).parameters return supports_kw(model_forward, "intermediate_tensors")
return "intermediate_tensors" in forward_params
@runtime_checkable @runtime_checkable

View File

@ -0,0 +1,191 @@
from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union,
overload, runtime_checkable)
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from typing_extensions import TypeIs, TypeVar
from vllm.logger import init_logger
from vllm.utils import supports_kw
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.pooler import PoolerOutput
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
logger = init_logger(__name__)
# The type of HF config
C_co = TypeVar("C_co", bound=PretrainedConfig, covariant=True)
# The type of hidden states
# Currently, T = torch.Tensor for all models except for Medusa
# which has T = List[torch.Tensor]
T = TypeVar("T", default=torch.Tensor)
T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)
# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags
# for the base interfaces to avoid breaking OOT registration for existing models
# that don't inherit from the base interface classes
@runtime_checkable
class VllmModel(Protocol[C_co, T_co]):
def __init__(
self,
config: C_co,
*,
cache_config: Optional["CacheConfig"],
quant_config: Optional["QuantizationConfig"],
) -> None:
...
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: "AttentionMetadata",
) -> T_co:
...
def _check_vllm_model_init(model: Union[Type[object], object]) -> bool:
model_init = model.__init__
vllm_kws = ("cache_config", "quant_config")
missing_kws = tuple(kw for kw in vllm_kws
if not supports_kw(model_init, kw))
if missing_kws and (isinstance(model, type)
and issubclass(model, nn.Module)):
logger.warning(
"The model (%s) is missing "
"vLLM-specific keywords from its initializer: %s",
model,
missing_kws,
)
return len(missing_kws) == 0
def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
model_forward = getattr(model, "forward", None)
if not callable(model_forward):
return False
vllm_kws = ("input_ids", "positions", "kv_caches", "attn_metadata")
missing_kws = tuple(kw for kw in vllm_kws
if not supports_kw(model_forward, kw))
if missing_kws and (isinstance(model, type)
and issubclass(model, nn.Module)):
logger.warning(
"The model (%s) is missing "
"vLLM-specific keywords from its initializer: %s",
model,
missing_kws,
)
return len(missing_kws) == 0
@overload
def is_vllm_model(model: Type[object]) -> TypeIs[Type[VllmModel]]:
...
@overload
def is_vllm_model(model: object) -> TypeIs[VllmModel]:
...
def is_vllm_model(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModel]], TypeIs[VllmModel]]:
return _check_vllm_model_init(model) and _check_vllm_model_forward(model)
@runtime_checkable
class VllmModelForTextGeneration(VllmModel[C_co, T], Protocol[C_co, T]):
def compute_logits(
self,
hidden_states: T,
sampling_metadata: "SamplingMetadata",
) -> Optional[T]:
"""Return `None` if TP rank > 0."""
...
def sample(
self,
logits: T,
sampling_metadata: "SamplingMetadata",
) -> "SamplerOutput":
"""Only called on TP rank 0."""
...
@overload
def is_text_generation_model(
model: Type[object]) -> TypeIs[Type[VllmModelForTextGeneration]]:
...
@overload
def is_text_generation_model(
model: object) -> TypeIs[VllmModelForTextGeneration]:
...
def is_text_generation_model(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModelForTextGeneration]],
TypeIs[VllmModelForTextGeneration]]:
if not is_vllm_model(model):
return False
if isinstance(model, type):
return isinstance(model, VllmModelForTextGeneration)
return isinstance(model, VllmModelForTextGeneration)
@runtime_checkable
class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]):
def pooler(
self,
hidden_states: T,
pooling_metadata: "PoolingMetadata",
) -> "PoolerOutput":
"""Only called on TP rank 0."""
...
@overload
def is_embedding_model(
model: Type[object]) -> TypeIs[Type[VllmModelForEmbedding]]:
...
@overload
def is_embedding_model(model: object) -> TypeIs[VllmModelForEmbedding]:
...
def is_embedding_model(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModelForEmbedding]], TypeIs[VllmModelForEmbedding]]:
if not is_vllm_model(model):
return False
if isinstance(model, type):
return isinstance(model, VllmModelForEmbedding)
return isinstance(model, VllmModelForEmbedding)

View File

@ -12,10 +12,12 @@ from vllm.logger import init_logger
from vllm.utils import is_hip from vllm.utils import is_hip
from .interfaces import supports_multimodal, supports_pp from .interfaces import supports_multimodal, supports_pp
from .interfaces_base import is_embedding_model, is_text_generation_model
logger = init_logger(__name__) logger = init_logger(__name__)
_GENERATION_MODELS = { _TEXT_GENERATION_MODELS = {
# [Decoder-only]
"AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
@ -74,10 +76,9 @@ _GENERATION_MODELS = {
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
"SolarForCausalLM": ("solar", "SolarForCausalLM"), "SolarForCausalLM": ("solar", "SolarForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
# NOTE: The below models are for speculative decoding only # [Encoder-decoder]
"MedusaModel": ("medusa", "Medusa"), "BartModel": ("bart", "BartForConditionalGeneration"),
"EAGLEModel": ("eagle", "EAGLE"), "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
} }
_EMBEDDING_MODELS = { _EMBEDDING_MODELS = {
@ -114,16 +115,18 @@ _MULTIMODAL_MODELS = {
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration": ("mllama",
"MllamaForConditionalGeneration"), "MllamaForConditionalGeneration"),
} }
_CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"), _SPECULATIVE_DECODING_MODELS = {
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), "EAGLEModel": ("eagle", "EAGLE"),
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
} }
_MODELS = { _MODELS = {
**_GENERATION_MODELS, **_TEXT_GENERATION_MODELS,
**_EMBEDDING_MODELS, **_EMBEDDING_MODELS,
**_MULTIMODAL_MODELS, **_MULTIMODAL_MODELS,
**_CONDITIONAL_GENERATION_MODELS, **_SPECULATIVE_DECODING_MODELS,
} }
# Architecture -> type or (module, class). # Architecture -> type or (module, class).
@ -317,6 +320,19 @@ class ModelRegistry:
return result.returncode == 0 return result.returncode == 0
@staticmethod
def is_text_generation_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
is_txt_gen = partial(ModelRegistry._check_stateless,
is_text_generation_model,
default=False)
return any(is_txt_gen(arch) for arch in architectures)
@staticmethod @staticmethod
def is_embedding_model(architectures: Union[str, List[str]]) -> bool: def is_embedding_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str): if isinstance(architectures, str):
@ -324,7 +340,11 @@ class ModelRegistry:
if not architectures: if not architectures:
logger.warning("No model architectures are specified") logger.warning("No model architectures are specified")
return any(arch in _EMBEDDING_MODELS for arch in architectures) is_emb = partial(ModelRegistry._check_stateless,
is_embedding_model,
default=False)
return any(is_emb(arch) for arch in architectures)
@staticmethod @staticmethod
def is_multimodal_model(architectures: Union[str, List[str]]) -> bool: def is_multimodal_model(architectures: Union[str, List[str]]) -> bool:

View File

@ -1277,6 +1277,15 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
return await task(*args, **kwargs) return await task(*args, **kwargs)
def supports_kw(callable: Callable[..., object], kw_name: str) -> bool:
params = inspect.signature(callable).parameters
if kw_name in params:
return True
return any(param.kind == inspect.Parameter.VAR_KEYWORD
for param in params.values())
def get_allowed_kwarg_only_overrides( def get_allowed_kwarg_only_overrides(
callable: Callable[..., object], callable: Callable[..., object],
overrides: Optional[Dict[str, Any]], overrides: Optional[Dict[str, Any]],