# NOTE: source-viewer header residue ("457 lines", "16 KiB", "Python",
# "Raw Normal View History") commented out; it is not part of the module.

import importlib
import os
import pickle
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union

import cloudpickle
import torch.nn as nn

from vllm.logger import init_logger
from vllm.utils import is_hip

from .interfaces import (has_inner_state, is_attention_free,
                         supports_multimodal, supports_pp)
from .interfaces_base import is_embedding_model, is_text_generation_model

logger = init_logger(__name__)
# yapf: disable
# Text-generation architectures.
# Each entry maps an architecture name to a
# `(module name relative to vllm.model_executor.models, class name)` pair.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    # ChatGLMModel supports multimodal
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    # QWenLMHeadModel supports multimodal
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}
# Embedding architectures: architecture name -> (module name, class name).
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
    "MistralModel": ("llama", "LlamaEmbeddingModel"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    # [Multimodal]
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
}
# Multimodal architectures: architecture name -> (module name, class name).
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    # [Encoder-decoder]
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
}
# Speculative-decoding architectures:
# architecture name -> (module name, class name).
_SPECULATIVE_DECODING_MODELS = {
    "EAGLEModel": ("eagle", "EAGLE"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
# yapf: enable
# Union of every built-in architecture table. Tables are merged in order, so
# an architecture listed in more than one table takes the value from the
# last table that contains it.
_VLLM_MODELS = dict(_TEXT_GENERATION_MODELS)
_VLLM_MODELS.update(_EMBEDDING_MODELS)
_VLLM_MODELS.update(_MULTIMODAL_MODELS)
_VLLM_MODELS.update(_SPECULATIVE_DECODING_MODELS)
# Architectures that ROCm cannot run at all.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Shared explanation for architectures limited by sliding window attention.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
                    "Triton flash attention. For half-precision SWA support, "
                    "please use CK flash attention by setting "
                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")

# Architectures that ROCm supports only partially.
# Architecture -> reason reported to the user.
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    "Qwen2ForCausalLM": _ROCM_SWA_REASON,
    "MistralForCausalLM": _ROCM_SWA_REASON,
    "MixtralForCausalLM": _ROCM_SWA_REASON,
    "PaliGemmaForConditionalGeneration": (
        "ROCm flash attention does not yet "
        "fully support 32-bit precision on PaliGemma"),
    "Phi3VForCausalLM": (
        "ROCm Triton flash attention may run into compilation errors due to "
        "excessive use of shared memory. If this happens, disable Triton FA "
        "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`"),
}
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable record of the interface flags of a model class.

    Each field holds the result of the same-named predicate imported from
    ``.interfaces`` / ``.interfaces_base`` applied to the model class.
    """
    is_text_generation_model: bool
    is_embedding_model: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool

    @staticmethod
    def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
        """Build a :class:`_ModelInfo` by running every predicate on `model`."""
        return _ModelInfo(
            is_text_generation_model=is_text_generation_model(model),
            is_embedding_model=is_embedding_model(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
        )
class _BaseRegisteredModel(ABC):
    """Abstract interface shared by every entry stored in the registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the interface flags of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> Type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """
    # Interface flags computed once from `model_cls`.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: Type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: Type[nn.Module]):
        """Wrap `model_cls`, computing its interface flags up front."""
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the stored interface flags."""
        return self.interfaces

    def load_model_cls(self) -> Type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls
@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Importable module path that defines the model class.
    module_name: str
    # Attribute name of the model class inside that module.
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Compute the interface flags of the model class in a subprocess."""

        def _inspect():
            return _ModelInfo.from_model_cls(self.load_model_cls())

        return _run_in_subprocess(_inspect)

    def load_model_cls(self) -> Type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        module = importlib.import_module(self.module_name)
        model_cls = getattr(module, self.class_name)
        return model_cls
@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
    """
    Load the class behind `model`, returning `None` if loading fails.

    On ROCm (`is_hip()`), an architecture listed as unsupported raises
    `ValueError`, and a partially supported one logs a warning first.
    Results, including `None`, are memoized by `lru_cache`.
    """
    if is_hip():
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            reason = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, reason)

    try:
        model_cls = model.load_model_cls()
    except Exception:
        # Best effort: record the failure and let the caller try the next
        # candidate architecture.
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        model_cls = None
    return model_cls
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """
    Inspect the class behind `model`, returning `None` if inspection fails.

    Results, including `None`, are memoized by `lru_cache`.
    """
    try:
        model_info = model.inspect_model_cls()
    except Exception:
        # Best effort: record the failure and let the caller try the next
        # candidate architecture.
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        model_info = None
    return model_info
@dataclass
class _ModelRegistry:
    """
    Registry that maps model architecture names to registered models and
    answers capability queries about them.
    """
    # Keyed by model_arch
    models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> List[str]:
        """Return the names of all registered architectures."""
        return list(self.models.keys())

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[Type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        :code:`model_cls` can be either:

        - A :class:`torch.nn.Module` class directly referencing the model.
        - A string in the format :code:`<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        An existing registration for `model_arch` is overwritten after a
        warning is logged. Raises `ValueError` if a string `model_cls` is not
        in the `<module>:<class>` format.
        """
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        else:
            model = _RegisteredModel.from_model_cls(model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: List[str]):
        """Always raise `ValueError` listing the supported architectures."""
        all_supported_archs = self.get_supported_archs()

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[Type[nn.Module]]:
        """Load the class for `model_arch`; `None` if unknown or failing."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect the class for `model_arch`; `None` if unknown or failing."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, List[str]],
    ) -> List[str]:
        """Wrap a single name in a list and warn when nothing is given."""
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        return architectures

    def inspect_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> _ModelInfo:
        """
        Return the interface flags of the first architecture that can be
        inspected; raise `ValueError` if none can.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return model_info

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[Type[nn.Module], str]:
        """
        Return `(model class, architecture name)` for the first architecture
        that can be loaded; raise `ValueError` if none can.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the `is_text_generation_model` flag of the resolved model."""
        return self.inspect_model_cls(architectures).is_text_generation_model

    def is_embedding_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the `is_embedding_model` flag of the resolved model."""
        return self.inspect_model_cls(architectures).is_embedding_model

    def is_multimodal_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the `supports_multimodal` flag of the resolved model."""
        return self.inspect_model_cls(architectures).supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the `supports_pp` flag of the resolved model."""
        return self.inspect_model_cls(architectures).supports_pp

    def model_has_inner_state(self, architectures: Union[str,
                                                         List[str]]) -> bool:
        """Return the `has_inner_state` flag of the resolved model."""
        return self.inspect_model_cls(architectures).has_inner_state

    def is_attention_free_model(self, architectures: Union[str,
                                                           List[str]]) -> bool:
        """Return the `is_attention_free` flag of the resolved model."""
        return self.inspect_model_cls(architectures).is_attention_free
# Global registry pre-populated with every built-in architecture. Entries are
# lazy, so no model module is imported until it is first needed.
ModelRegistry = _ModelRegistry(
    models={
        arch: _LazyRegisteredModel(
            f"vllm.model_executor.models.{module_relname}",
            model_cls_name,
        )
        for arch, (module_relname, model_cls_name) in _VLLM_MODELS.items()
    })
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Run `fn` in a fresh Python subprocess and return its result.

    `fn` and an output path are sent to the child on stdin via `cloudpickle`.
    The child (this module run with `-m`) writes the pickled result to that
    path, which is then read back here.

    Raises:
        RuntimeError: if the subprocess exits with a non-zero return code;
            the child's stderr is included in the message.
    """
    # NOTE: A temporary directory is used instead of `NamedTemporaryFile`
    # because the child process must open the output path by name, and
    # whether a still-open named temporary file can be opened a second time
    # is platform-dependent (it cannot on Windows).
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            [sys.executable, "-m", "vllm.model_executor.models.registry"],
            input=input_bytes,
            capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except Exception as e:
            # wrap raised exception to provide more information;
            # undecodable bytes are replaced so that they cannot mask the
            # original failure with a UnicodeDecodeError
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        # NOTE(security): the file is produced by our own child process;
        # `pickle.load` must never be pointed at untrusted data.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)
def _run() -> None:
    """
    Entry point used when this module is executed as a script.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn`, and
    writes the pickled result to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    result = fn()

    with open(output_file, "wb") as out:
        out.write(pickle.dumps(result))


if __name__ == "__main__":
    _run()