From add055e151f32f89dab5932d25e5285b2fc823f1 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Tue, 9 May 2023 15:46:42 -0700
Subject: [PATCH] Enhance model loader (#83)

---
 cacheflow/core/server.py                 |  2 +-
 cacheflow/model_executor/model_loader.py | 96 ++++++++++++++----
 2 files changed, 56 insertions(+), 42 deletions(-)

diff --git a/cacheflow/core/server.py b/cacheflow/core/server.py
index 9047ccf6..9eb96efd 100644
--- a/cacheflow/core/server.py
+++ b/cacheflow/core/server.py
@@ -12,8 +12,8 @@ from cacheflow.core.scheduler import Scheduler
 from cacheflow.frontend.simple_frontend import SimpleFrontend
 from cacheflow.logger import init_logger
 from cacheflow.model_executor import get_memory_analyzer
-from cacheflow.sequence import SequenceGroup
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import SequenceGroup
 from cacheflow.utils import get_gpu_memory, get_cpu_memory
 from cacheflow.worker.controller import Controller, DeviceID
 
diff --git a/cacheflow/model_executor/model_loader.py b/cacheflow/model_executor/model_loader.py
index 1af2c606..5598309e 100644
--- a/cacheflow/model_executor/model_loader.py
+++ b/cacheflow/model_executor/model_loader.py
@@ -14,32 +14,51 @@ from cacheflow.model_executor.utils import get_torch_dtype
 from cacheflow.model_executor.weight_utils import initialize_dummy_weights
 
 
-_MODELS = {
-    'gpt2': GPT2LMHeadModel,
-    'llama': LlamaForCausalLM,
-    'opt': OPTForCausalLM,
-    'stablelm': GPTNeoXForCausalLM,
-    'pythia': GPTNeoXForCausalLM,
-    'dolly-v2': GPTNeoXForCausalLM,
+# TODO(woosuk): Lazy-load the model classes.
+_MODEL_REGISTRY = {
+    "GPT2LMHeadModel": GPT2LMHeadModel,
+    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
+    "LlamaForCausalLM": LlamaForCausalLM,
+    "OPTForCausalLM": OPTForCausalLM,
 }
 
 _MEMORY_ANALYZERS = {
-    'gpt2': GPT2MemoryAnalyzer,
-    'llama': LlamaMemoryAnalyzer,
-    'opt': OPTMemoryAnalyzer,
-    'stablelm': GPTNeoXMemoryAnalyzer,
-    'pythia': GPTNeoXMemoryAnalyzer,
-    'dolly-v2': GPTNeoXMemoryAnalyzer,
+    "GPT2LMHeadModel": GPT2MemoryAnalyzer,
+    "GPTNeoXForCausalLM": GPTNeoXMemoryAnalyzer,
+    "LlamaForCausalLM": LlamaMemoryAnalyzer,
+    "OPTForCausalLM": OPTMemoryAnalyzer,
 }
 
 
+def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in _MODEL_REGISTRY:
+            return _MODEL_REGISTRY[arch]
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
+    )
+
+
+def _get_memory_analyzer(config: PretrainedConfig) -> CacheFlowMemoryAnalyzer:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in _MEMORY_ANALYZERS:
+            return _MEMORY_ANALYZERS[arch]
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {list(_MEMORY_ANALYZERS.keys())}"
+    )
+
+
 def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
-    # NOTE: getattr(config, 'torch_dtype', torch.float32) is not correct
+    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
-    config_dtype = getattr(config, 'torch_dtype', None)
+    config_dtype = getattr(config, "torch_dtype", None)
     if config_dtype is None:
         config_dtype = torch.float32
-    if dtype == 'default':
+    if dtype == "default":
         if config_dtype == torch.float32:
             # Following the common practice, we use float16 for float32 models.
             torch_dtype = torch.float16
@@ -51,7 +70,7 @@ def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
         # TODO(woosuk): Allow using float16 for bfloat16 models and
         # vice versa. Print a warning message and continue.
         raise ValueError(
-            f'Cannot use {torch_dtype} for {config_dtype} model.')
+            f"Cannot use {torch_dtype} for {config_dtype} model.")
     return torch_dtype
 
 
@@ -65,24 +84,21 @@ def get_model(
     config = AutoConfig.from_pretrained(model_name)
     torch_dtype = _get_dtype(config, dtype)
     torch.set_default_dtype(torch_dtype)
-    for model_class_name, model_class in _MODELS.items():
-        if model_class_name in model_name:
-            if use_dummy_weights:
-                # Create a model instance.
-                # The weights will be initialized as empty tensors.
-                model = model_class(config)
-                model = model.cuda()
-                # NOTE(woosuk): For precise performance evaluation, we assign
-                # random values to the weights.
-                initialize_dummy_weights(model)
-            else:
-                # Create a model instance.
-                model = model_class(config)
-                # Load the weights from the cached or downloaded files.
-                model.load_weights(model_name, cache_dir, use_np_cache)
-                model = model.cuda()
-            return model.eval(), torch_dtype
-    raise ValueError(f'Unsupported model name: {model_name}')
+    model_class = _get_model_architecture(config)
+
+    # Create a model instance.
+    # The weights will be initialized as empty tensors.
+    model = model_class(config)
+    if use_dummy_weights:
+        model = model.cuda()
+        # NOTE(woosuk): For accurate performance evaluation, we assign
+        # random values to the weights.
+        initialize_dummy_weights(model)
+    else:
+        # Load the weights from the cached or downloaded files.
+        model.load_weights(model_name, cache_dir, use_np_cache)
+        model = model.cuda()
+    return model.eval(), torch_dtype
 
 
 def get_memory_analyzer(
@@ -95,9 +111,7 @@ def get_memory_analyzer(
 ) -> CacheFlowMemoryAnalyzer:
     config = AutoConfig.from_pretrained(model_name)
     torch_dtype = _get_dtype(config, dtype)
-    for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
-        if model_class in model_name:
-            return memory_analyzer(
-                model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
-                tensor_parallel_size)
-    raise ValueError(f'Unsupported model name: {model_name}')
+    memory_analyzer = _get_memory_analyzer(config)
+    return memory_analyzer(
+        model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
+        tensor_parallel_size)
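
Note: the central change here is that model and memory-analyzer classes are now
selected by the architecture names advertised in the Hugging Face config, rather
than by substring-matching the model name. A minimal sketch of where those names
come from, assuming transformers is installed; "EleutherAI/pythia-160m" is only an
illustrative checkpoint, not something referenced by this patch:

    from transformers import AutoConfig

    # Pythia checkpoints declare the GPT-NeoX architecture in their config,
    # so the lookup no longer depends on "pythia" appearing in the model name.
    config = AutoConfig.from_pretrained("EleutherAI/pythia-160m")
    print(config.architectures)  # expected: ['GPTNeoXForCausalLM']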
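
For reference, a self-contained sketch of the lookup and its new failure mode. A
stub dict stands in for _MODEL_REGISTRY so this runs without cacheflow installed;
the keys are the architecture names registered in this patch:

    # Stub registry keyed on architecture names, mirroring _MODEL_REGISTRY above.
    _STUB_REGISTRY = {
        "GPT2LMHeadModel": "GPT2LMHeadModel",
        "GPTNeoXForCausalLM": "GPTNeoXForCausalLM",
        "LlamaForCausalLM": "LlamaForCausalLM",
        "OPTForCausalLM": "OPTForCausalLM",
    }

    def lookup(architectures):
        # The first architecture found in the registry wins, as in
        # _get_model_architecture / _get_memory_analyzer.
        for arch in architectures or []:
            if arch in _STUB_REGISTRY:
                return _STUB_REGISTRY[arch]
        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {list(_STUB_REGISTRY.keys())}")

    print(lookup(["LlamaForCausalLM"]))  # -> LlamaForCausalLM
    lookup(["BloomForCausalLM"])         # raises ValueError listing supported architectures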