Enhance model loader (#83)

Woosuk Kwon 2023-05-09 15:46:42 -07:00 committed by GitHub
parent 7c041ab578
commit add055e151
2 changed files with 56 additions and 42 deletions


@@ -12,8 +12,8 @@ from cacheflow.core.scheduler import Scheduler
 from cacheflow.frontend.simple_frontend import SimpleFrontend
 from cacheflow.logger import init_logger
 from cacheflow.model_executor import get_memory_analyzer
-from cacheflow.sequence import SequenceGroup
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import SequenceGroup
 from cacheflow.utils import get_gpu_memory, get_cpu_memory
 from cacheflow.worker.controller import Controller, DeviceID


@@ -14,32 +14,51 @@ from cacheflow.model_executor.utils import get_torch_dtype
 from cacheflow.model_executor.weight_utils import initialize_dummy_weights
 
-_MODELS = {
-    'gpt2': GPT2LMHeadModel,
-    'llama': LlamaForCausalLM,
-    'opt': OPTForCausalLM,
-    'stablelm': GPTNeoXForCausalLM,
-    'pythia': GPTNeoXForCausalLM,
-    'dolly-v2': GPTNeoXForCausalLM,
+# TODO(woosuk): Lazy-load the model classes.
+_MODEL_REGISTRY = {
+    "GPT2LMHeadModel": GPT2LMHeadModel,
+    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
+    "LlamaForCausalLM": LlamaForCausalLM,
+    "OPTForCausalLM": OPTForCausalLM,
 }
 
 _MEMORY_ANALYZERS = {
-    'gpt2': GPT2MemoryAnalyzer,
-    'llama': LlamaMemoryAnalyzer,
-    'opt': OPTMemoryAnalyzer,
-    'stablelm': GPTNeoXMemoryAnalyzer,
-    'pythia': GPTNeoXMemoryAnalyzer,
-    'dolly-v2': GPTNeoXMemoryAnalyzer,
+    "GPT2LMHeadModel": GPT2MemoryAnalyzer,
+    "GPTNeoXForCausalLM": GPTNeoXMemoryAnalyzer,
+    "LlamaForCausalLM": LlamaMemoryAnalyzer,
+    "OPTForCausalLM": OPTMemoryAnalyzer,
 }
 
 
+def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in _MODEL_REGISTRY:
+            return _MODEL_REGISTRY[arch]
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
+    )
+
+
+def _get_memory_analyzer(config: PretrainedConfig) -> CacheFlowMemoryAnalyzer:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in _MEMORY_ANALYZERS:
+            return _MEMORY_ANALYZERS[arch]
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {list(_MEMORY_ANALYZERS.keys())}"
+    )
+
+
 def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
-    # NOTE: getattr(config, 'torch_dtype', torch.float32) is not correct
+    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
-    config_dtype = getattr(config, 'torch_dtype', None)
+    config_dtype = getattr(config, "torch_dtype", None)
     if config_dtype is None:
         config_dtype = torch.float32
-    if dtype == 'default':
+    if dtype == "default":
         if config_dtype == torch.float32:
             # Following the common practice, we use float16 for float32 models.
             torch_dtype = torch.float16
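Note on the hunk above: model resolution now keys off the architectures field of the Hugging Face config instead of substring-matching the model name, so "stablelm", "pythia", and "dolly-v2" no longer need their own aliases; their configs all report GPTNeoXForCausalLM. A minimal standalone sketch of the same lookup pattern (the registry values and the example model name are illustrative, not taken from this diff):

from transformers import AutoConfig

# Illustrative registry: architecture name -> implementation to use.
MODEL_REGISTRY = {
    "OPTForCausalLM": "opt-impl",
    "GPTNeoXForCausalLM": "gpt-neox-impl",
}

def resolve_architecture(model_name: str) -> str:
    config = AutoConfig.from_pretrained(model_name)
    # config.architectures is a list such as ["OPTForCausalLM"].
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in MODEL_REGISTRY:
            return MODEL_REGISTRY[arch]
    raise ValueError(f"Unsupported architectures: {architectures}")

# Resolves via "OPTForCausalLM"; a Dolly or Pythia checkpoint would resolve
# via "GPTNeoXForCausalLM" even though neither string appears in its name.
print(resolve_architecture("facebook/opt-125m"))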
@@ -51,7 +70,7 @@ def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
             # TODO(woosuk): Allow using float16 for bfloat16 models and
             # vice versa. Print a warning message and continue.
             raise ValueError(
-                f'Cannot use {torch_dtype} for {config_dtype} model.')
+                f"Cannot use {torch_dtype} for {config_dtype} model.")
     return torch_dtype
 
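On the NOTE in _get_dtype: a getattr default only applies when the attribute is missing, but PretrainedConfig defines torch_dtype and leaves it as None unless the checkpoint sets it, so the default is silently ignored. A small illustration of the pitfall (uses the real transformers PretrainedConfig; nothing else is assumed):

import torch
from transformers import PretrainedConfig

config = PretrainedConfig()  # torch_dtype exists on the config, but is None

# The getattr default never applies, because the attribute is present:
assert getattr(config, "torch_dtype", torch.float32) is None

# The pattern used in this diff handles the None case explicitly:
config_dtype = getattr(config, "torch_dtype", None)
if config_dtype is None:
    config_dtype = torch.float32
assert config_dtype == torch.float32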
@@ -65,24 +84,21 @@ def get_model(
     config = AutoConfig.from_pretrained(model_name)
     torch_dtype = _get_dtype(config, dtype)
     torch.set_default_dtype(torch_dtype)
-    for model_class_name, model_class in _MODELS.items():
-        if model_class_name in model_name:
-            if use_dummy_weights:
-                # Create a model instance.
-                # The weights will be initialized as empty tensors.
-                model = model_class(config)
-                model = model.cuda()
-                # NOTE(woosuk): For precise performance evaluation, we assign
-                # random values to the weights.
-                initialize_dummy_weights(model)
-            else:
-                # Create a model instance.
-                model = model_class(config)
-                # Load the weights from the cached or downloaded files.
-                model.load_weights(model_name, cache_dir, use_np_cache)
-                model = model.cuda()
-            return model.eval(), torch_dtype
-    raise ValueError(f'Unsupported model name: {model_name}')
+    model_class = _get_model_architecture(config)
+    # Create a model instance.
+    # The weights will be initialized as empty tensors.
+    model = model_class(config)
+    if use_dummy_weights:
+        model = model.cuda()
+        # NOTE(woosuk): For accurate performance evaluation, we assign
+        # random values to the weights.
+        initialize_dummy_weights(model)
+    else:
+        # Load the weights from the cached or downloaded files.
+        model.load_weights(model_name, cache_dir, use_np_cache)
+        model = model.cuda()
+    return model.eval(), torch_dtype
 
 
 def get_memory_analyzer(
@@ -95,9 +111,7 @@ def get_memory_analyzer(
 ) -> CacheFlowMemoryAnalyzer:
     config = AutoConfig.from_pretrained(model_name)
     torch_dtype = _get_dtype(config, dtype)
-    for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
-        if model_class in model_name:
-            return memory_analyzer(
-                model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
-                tensor_parallel_size)
-    raise ValueError(f'Unsupported model name: {model_name}')
+    memory_analyzer = _get_memory_analyzer(config)
+    return memory_analyzer(
+        model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
+        tensor_parallel_size)
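With these changes, get_model resolves the class once from the config, constructs the model with empty weights, and only then branches on use_dummy_weights (random weights for performance runs, no download) versus load_weights. A hedged usage sketch; the parameter names below appear in the diff body, but the full signature and the import path for get_model are assumptions:

# Assumed import path, mirroring the confirmed
# `from cacheflow.model_executor import get_memory_analyzer` import above.
from cacheflow.model_executor import get_model

model, torch_dtype = get_model(
    model_name="facebook/opt-125m",  # resolved via config.architectures
    dtype="default",                 # float32 checkpoints are served in float16
    cache_dir=None,                  # assumed default; not shown in this diff
    use_dummy_weights=True,          # skip downloading weights for perf testing
    use_np_cache=False,
)
print(type(model).__name__, torch_dtype)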