vllm/cacheflow/model_executor/model_loader.py


"""Utilities for selecting and loading models."""
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
from transformers import AutoConfig, PretrainedConfig

from cacheflow.model_executor.models import (
    GPT2LMHeadModel, GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM)
from cacheflow.model_executor.utils import get_torch_dtype
from cacheflow.model_executor.weight_utils import initialize_dummy_weights

# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {
    "GPT2LMHeadModel": GPT2LMHeadModel,
    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
    "LlamaForCausalLM": LlamaForCausalLM,
    "OPTForCausalLM": OPTForCausalLM,
}
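# For reference (illustrative note, not part of the original file): the keys
# above match the `architectures` field of a HuggingFace config, so resolution
# works like this, assuming the example checkpoint "facebook/opt-125m":
#
#     config = AutoConfig.from_pretrained("facebook/opt-125m")
#     config.architectures                  # ["OPTForCausalLM"]
#     _get_model_architecture(config)       # returns OPTForCausalLM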
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")


def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    if dtype == "default":
        if config_dtype == torch.float32:
            # Following common practice, use float16 for float32 models.
            torch_dtype = torch.float16
        else:
            torch_dtype = config_dtype
    else:
        torch_dtype = get_torch_dtype(dtype)
        if torch_dtype != config_dtype and config_dtype != torch.float32:
            # TODO(woosuk): Allow using float16 for bfloat16 models and
            # vice versa. Print a warning message and continue.
            raise ValueError(
                f"Cannot use {torch_dtype} for {config_dtype} model.")
    return torch_dtype
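# Illustrative resolution table for _get_dtype (a sketch assuming
# get_torch_dtype maps strings such as "half" to the matching torch dtype;
# rows are examples, not exhaustive):
#
#     config.torch_dtype    dtype        result
#     None                  "default"    torch.float16  (float32 -> float16)
#     torch.bfloat16        "default"    torch.bfloat16
#     torch.bfloat16        "half"       ValueError (dtype/config mismatch)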
def get_model(
    model_name: str,
    dtype: str,
    cache_dir: Optional[str],
    use_dummy_weights: bool,
    use_np_cache: bool,
) -> Tuple[nn.Module, torch.dtype]:
    config = AutoConfig.from_pretrained(model_name)
    torch_dtype = _get_dtype(config, dtype)
    torch.set_default_dtype(torch_dtype)
    model_class = _get_model_architecture(config)

    # Create a model instance.
    # The weights will be initialized as empty tensors.
    model = model_class(config)
    if use_dummy_weights:
        model = model.cuda()
        # NOTE(woosuk): For accurate performance evaluation, we assign
        # random values to the weights.
        initialize_dummy_weights(model)
    else:
        # Load the weights from the cached or downloaded files.
        model.load_weights(model_name, cache_dir, use_np_cache)
        model = model.cuda()
    return model.eval(), torch_dtype
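# Example usage (a minimal sketch; the model name and flag values are
# illustrative, not prescribed by this module):
#
#     model, torch_dtype = get_model(
#         model_name="facebook/opt-125m",
#         dtype="default",
#         cache_dir=None,
#         use_dummy_weights=False,
#         use_np_cache=False,
#     )
#     # `model` lives on the GPU in eval mode, with weights in `torch_dtype`.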