vllm/vllm/model_executor/model_loader.py

53 lines
1.9 KiB
Python
Raw Normal View History

"""Utilities for selecting and loading models."""
from typing import Type
2023-02-23 21:31:39 +00:00
import torch
2023-02-13 09:36:12 +00:00
import torch.nn as nn
2023-05-20 13:06:59 -07:00
from transformers import PretrainedConfig
2023-02-13 09:36:12 +00:00
2023-06-17 03:07:40 -07:00
from vllm.config import ModelConfig
from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM,
GPTNeoXForCausalLM, LlamaForCausalLM,
OPTForCausalLM)
2023-06-17 03:07:40 -07:00
from vllm.model_executor.weight_utils import initialize_dummy_weights
2023-02-13 09:36:12 +00:00
2023-05-09 15:46:42 -07:00
# Registry mapping architecture names (as listed in a Hugging Face config's
# ``architectures`` field) to the vLLM model class that implements them.
# Insertion order matters only cosmetically: it is the order in which the
# supported architectures are listed in the "not supported" error message.
# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {
    # GPT-2 family.
    "GPT2LMHeadModel": GPT2LMHeadModel,
    # GPTBigCode family.
    "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
    # GPT-NeoX family.
    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
    # LLaMA family. Two spellings map to the same class; the second is
    # presumably the capitalization used by some LLaMA checkpoint configs
    # (NOTE(review): confirm which checkpoints need the alias).
    "LlamaForCausalLM": LlamaForCausalLM,
    "LLaMAForCausalLM": LlamaForCausalLM,
    # OPT family.
    "OPTForCausalLM": OPTForCausalLM,
}
2023-05-20 13:06:59 -07:00
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    """Return the vLLM model class for the first supported architecture.

    Args:
        config: The Hugging Face model config. Its ``architectures``
            attribute is scanned in order for a name registered in
            ``_MODEL_REGISTRY``; the attribute may be absent or ``None``.

    Returns:
        The model class registered for the first entry of
        ``config.architectures`` that appears in ``_MODEL_REGISTRY``.

    Raises:
        ValueError: If none of the listed architectures is supported,
            including when the config lists no architectures at all.
    """
    # ``PretrainedConfig`` stores ``architectures = None`` when the config does
    # not specify it, so a ``getattr`` default alone is not enough: iterating
    # ``None`` would raise an unhelpful ``TypeError`` instead of the
    # ``ValueError`` below that tells the user what is supported.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        model_class = _MODEL_REGISTRY.get(arch)
        if model_class is not None:
            return model_class
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
2023-05-09 15:46:42 -07:00
2023-05-20 13:06:59 -07:00
def get_model(model_config: ModelConfig) -> nn.Module:
    """Build the model described by ``model_config`` and return it on CUDA.

    The model class is picked from the Hugging Face config's
    ``architectures`` field. Its weights are then either loaded from the
    checkpoint named by ``model_config.model`` or, when
    ``model_config.use_dummy_weights`` is set, filled with random values.

    Args:
        model_config: The vLLM model configuration.

    Returns:
        The model after ``.cuda()`` and ``.eval()`` have been applied.

    Raises:
        ValueError: If no architecture listed in the HF config is supported.
    """
    hf_config = model_config.hf_config
    model_class = _get_model_architecture(hf_config)

    # NOTE(review): this sets torch's process-wide default dtype and never
    # restores it, so code running after this call also sees
    # ``model_config.dtype`` as the default. Confirm callers rely on this
    # before scoping it.
    torch.set_default_dtype(model_config.dtype)

    # Instantiate the module. Its weights start out as empty tensors and are
    # filled in by one of the branches below.
    model = model_class(hf_config)

    if not model_config.use_dummy_weights:
        # Real weights: read them from the cached or downloaded files first,
        # then move the populated model to the GPU.
        model.load_weights(model_config.model, model_config.download_dir,
                           model_config.use_np_weights)
        model = model.cuda()
    else:
        # Dummy weights: move to the GPU first, then assign random values
        # so that performance measurements remain meaningful.
        model = model.cuda()
        initialize_dummy_weights(model)

    return model.eval()