"""Utilities for selecting and loading models."""
import contextlib
from typing import Type

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config import ModelConfig
from vllm.model_executor.models import *  # pylint: disable=wildcard-import
from vllm.model_executor.weight_utils import (get_quant_config,
                                              initialize_dummy_weights)
# TODO(woosuk): Lazy-load the model classes.
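# Maps the architecture name declared in a Hugging Face config's
# `architectures` field to the corresponding vLLM model class.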
_MODEL_REGISTRY = {
    "AquilaModel": AquilaForCausalLM,
    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
    "BloomForCausalLM": BloomForCausalLM,
    "FalconForCausalLM": FalconForCausalLM,
    "GPT2LMHeadModel": GPT2LMHeadModel,
    "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
    "GPTJForCausalLM": GPTJForCausalLM,
    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
    "InternLMForCausalLM": InternLMForCausalLM,
    "LlamaForCausalLM": LlamaForCausalLM,
    "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
    "MPTForCausalLM": MPTForCausalLM,
    "OPTForCausalLM": OPTForCausalLM,
    "QWenLMHeadModel": QWenLMHeadModel,
    "RWForCausalLM": FalconForCausalLM,
}

# FIXME(woosuk): Remove this once all models support quantization.
_MODEL_CLASSES_SUPPORT_QUANTIZATION = [
    LlamaForCausalLM,
]


@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(old_dtype)


def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    """Returns the model class registered for the config's architecture."""
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")


def get_model(model_config: ModelConfig) -> nn.Module:
    """Creates a model instance and loads its weights onto the GPU.

    The model is first instantiated with empty weights in the configured
    dtype, then either filled with random values (load_format="dummy") or
    loaded from the cached/downloaded checkpoint, and returned in eval mode.
    """
    model_class = _get_model_architecture(model_config.hf_config)

    # Get the quantization config.
    quant_config = None
    if model_config.quantization is not None:
        if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
            raise ValueError(
                f"Quantization is not supported for {model_class}.")
        quant_config = get_quant_config(model_config.quantization,
                                        model_config.model,
                                        model_config.download_dir)
        # Encode the compute capability as an integer, e.g. (8, 0) -> 80.
        capability = torch.cuda.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        if capability < quant_config.get_min_capability():
            raise ValueError(
                f"The quantization method {model_config.quantization} is not "
                "supported for the current GPU. "
                f"Minimum capability: {quant_config.get_min_capability()}. "
                f"Current capability: {capability}.")
        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}")

    with _set_default_torch_dtype(model_config.dtype):
        # Create a model instance.
        # The weights will be initialized as empty tensors.
        if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
            model = model_class(model_config.hf_config, quant_config)
        else:
            model = model_class(model_config.hf_config)
        if model_config.load_format == "dummy":
            model = model.cuda()
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        else:
            # Load the weights from the cached or downloaded files.
            model.load_weights(model_config.model, model_config.download_dir,
                               model_config.load_format,
                               model_config.revision)
            model = model.cuda()
    return model.eval()
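

# Usage sketch: assuming a ModelConfig has already been constructed by the
# engine/worker setup code, loading a model boils down to
#
#     model = get_model(model_config)
#
# after which the returned module lives on the GPU and is in eval mode.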