457 lines
16 KiB
Python

import importlib
import pickle
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
import cloudpickle
import torch.nn as nn
from vllm.logger import init_logger
from vllm.utils import is_hip
from .interfaces import (has_inner_state, is_attention_free,
supports_multimodal, supports_pp)
from .interfaces_base import is_embedding_model, is_text_generation_model
logger = init_logger(__name__)
# yapf: disable
_TEXT_GENERATION_MODELS = {
# [Decoder-only]
"AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
# ChatGLMModel supports multimodal
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
"GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
"GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
"FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
# transformers's mpt class has lower case
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
# QWenLMHeadModel supports multimodal
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
# [Encoder-decoder]
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}
_EMBEDDING_MODELS = {
# [Text-only]
"BertModel": ("bert", "BertEmbeddingModel"),
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
"MistralModel": ("llama", "LlamaEmbeddingModel"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
# [Multimodal]
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
}
_MULTIMODAL_MODELS = {
# [Decoder-only]
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501
"LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501
"MiniCPMV": ("minicpmv", "MiniCPMV"),
"MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
"NVLM_D": ("nvlm_d", "NVLM_D_Model"),
"PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"),
# [Encoder-decoder]
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
}
_SPECULATIVE_DECODING_MODELS = {
"EAGLEModel": ("eagle", "EAGLE"),
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
# yapf: enable
_VLLM_MODELS = {
**_TEXT_GENERATION_MODELS,
**_EMBEDDING_MODELS,
**_MULTIMODAL_MODELS,
**_SPECULATIVE_DECODING_MODELS,
}
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []
# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
"Triton flash attention. For half-precision SWA support, "
"please use CK flash attention by setting "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
"Qwen2ForCausalLM":
_ROCM_SWA_REASON,
"MistralForCausalLM":
_ROCM_SWA_REASON,
"MixtralForCausalLM":
_ROCM_SWA_REASON,
"PaliGemmaForConditionalGeneration":
("ROCm flash attention does not yet "
"fully support 32-bit precision on PaliGemma"),
"Phi3VForCausalLM":
("ROCm Triton flash attention may run into compilation errors due to "
"excessive use of shared memory. If this happens, disable Triton FA "
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}
@dataclass(frozen=True)
class _ModelInfo:
is_text_generation_model: bool
is_embedding_model: bool
supports_multimodal: bool
supports_pp: bool
has_inner_state: bool
is_attention_free: bool
@staticmethod
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
return _ModelInfo(
is_text_generation_model=is_text_generation_model(model),
is_embedding_model=is_embedding_model(model),
supports_multimodal=supports_multimodal(model),
supports_pp=supports_pp(model),
has_inner_state=has_inner_state(model),
is_attention_free=is_attention_free(model),
)
class _BaseRegisteredModel(ABC):
@abstractmethod
def inspect_model_cls(self) -> _ModelInfo:
raise NotImplementedError
@abstractmethod
def load_model_cls(self) -> Type[nn.Module]:
raise NotImplementedError
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
"""
Represents a model that has already been imported in the main process.
"""
interfaces: _ModelInfo
model_cls: Type[nn.Module]
@staticmethod
def from_model_cls(model_cls: Type[nn.Module]):
return _RegisteredModel(
interfaces=_ModelInfo.from_model_cls(model_cls),
model_cls=model_cls,
)
def inspect_model_cls(self) -> _ModelInfo:
return self.interfaces
def load_model_cls(self) -> Type[nn.Module]:
return self.model_cls
@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
"""
Represents a model that has not been imported in the main process.
"""
module_name: str
class_name: str
# Performed in another process to avoid initializing CUDA
def inspect_model_cls(self) -> _ModelInfo:
return _run_in_subprocess(
lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
def load_model_cls(self) -> Type[nn.Module]:
mod = importlib.import_module(self.module_name)
return getattr(mod, self.class_name)
@lru_cache(maxsize=128)
def _try_load_model_cls(
model_arch: str,
model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
if is_hip():
if model_arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(f"Model architecture '{model_arch}' is not "
"supported by ROCm for now.")
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
logger.warning(
"Model architecture '%s' is partially "
"supported by ROCm: %s", model_arch, msg)
try:
return model.load_model_cls()
except Exception:
logger.exception("Error in loading model architecture '%s'",
model_arch)
return None
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
model_arch: str,
model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
try:
return model.inspect_model_cls()
except Exception:
logger.exception("Error in inspecting model architecture '%s'",
model_arch)
return None
@dataclass
class _ModelRegistry:
# Keyed by model_arch
models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)
def get_supported_archs(self) -> List[str]:
return list(self.models.keys())
def register_model(
self,
model_arch: str,
model_cls: Union[Type[nn.Module], str],
) -> None:
"""
Register an external model to be used in vLLM.
:code:`model_cls` can be either:
- A :class:`torch.nn.Module` class directly referencing the model.
- A string in the format :code:`<module>:<class>` which can be used to
lazily import the model. This is useful to avoid initializing CUDA
when importing the model and thus the related error
:code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
"""
if model_arch in self.models:
logger.warning(
"Model architecture %s is already registered, and will be "
"overwritten by the new model class %s.", model_arch,
model_cls)
if isinstance(model_cls, str):
split_str = model_cls.split(":")
if len(split_str) != 2:
msg = "Expected a string in the format `<module>:<class>`"
raise ValueError(msg)
model = _LazyRegisteredModel(*split_str)
else:
model = _RegisteredModel.from_model_cls(model_cls)
self.models[model_arch] = model
def _raise_for_unsupported(self, architectures: List[str]):
all_supported_archs = self.get_supported_archs()
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {all_supported_archs}")
def _try_load_model_cls(self,
model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch not in self.models:
return None
return _try_load_model_cls(model_arch, self.models[model_arch])
def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
if model_arch not in self.models:
return None
return _try_inspect_model_cls(model_arch, self.models[model_arch])
def _normalize_archs(
self,
architectures: Union[str, List[str]],
) -> List[str]:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
return architectures
def inspect_model_cls(
self,
architectures: Union[str, List[str]],
) -> _ModelInfo:
architectures = self._normalize_archs(architectures)
for arch in architectures:
model_info = self._try_inspect_model_cls(arch)
if model_info is not None:
return model_info
return self._raise_for_unsupported(architectures)
def resolve_model_cls(
self,
architectures: Union[str, List[str]],
) -> Tuple[Type[nn.Module], str]:
architectures = self._normalize_archs(architectures)
for arch in architectures:
model_cls = self._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
return self._raise_for_unsupported(architectures)
def is_text_generation_model(
self,
architectures: Union[str, List[str]],
) -> bool:
return self.inspect_model_cls(architectures).is_text_generation_model
def is_embedding_model(
self,
architectures: Union[str, List[str]],
) -> bool:
return self.inspect_model_cls(architectures).is_embedding_model
def is_multimodal_model(
self,
architectures: Union[str, List[str]],
) -> bool:
return self.inspect_model_cls(architectures).supports_multimodal
def is_pp_supported_model(
self,
architectures: Union[str, List[str]],
) -> bool:
return self.inspect_model_cls(architectures).supports_pp
def model_has_inner_state(self, architectures: Union[str,
List[str]]) -> bool:
return self.inspect_model_cls(architectures).has_inner_state
def is_attention_free_model(self, architectures: Union[str,
List[str]]) -> bool:
return self.inspect_model_cls(architectures).is_attention_free
ModelRegistry = _ModelRegistry({
model_arch: _LazyRegisteredModel(
module_name=f"vllm.model_executor.models.{mod_relname}",
class_name=cls_name,
)
for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})
_T = TypeVar("_T")
def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
with tempfile.NamedTemporaryFile() as output_file:
# `cloudpickle` allows pickling lambda functions directly
input_bytes = cloudpickle.dumps((fn, output_file.name))
# cannot use `sys.executable __file__` here because the script
# contains relative imports
returned = subprocess.run(
[sys.executable, "-m", "vllm.model_executor.models.registry"],
input=input_bytes,
capture_output=True)
# check if the subprocess is successful
try:
returned.check_returncode()
except Exception as e:
# wrap raised exception to provide more information
raise RuntimeError(f"Error raised in subprocess:\n"
f"{returned.stderr.decode()}") from e
with open(output_file.name, "rb") as f:
return pickle.load(f)
def _run() -> None:
# Setup plugins
from vllm.plugins import load_general_plugins
load_general_plugins()
fn, output_file = pickle.loads(sys.stdin.buffer.read())
result = fn()
with open(output_file, "wb") as f:
f.write(pickle.dumps(result))
if __name__ == "__main__":
_run()