Enhance model loader (#83)

Woosuk Kwon 2023-05-09 15:46:42 -07:00 committed by GitHub
parent 7c041ab578
commit add055e151
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 56 additions and 42 deletions

View File

@@ -12,8 +12,8 @@ from cacheflow.core.scheduler import Scheduler
 from cacheflow.frontend.simple_frontend import SimpleFrontend
 from cacheflow.logger import init_logger
 from cacheflow.model_executor import get_memory_analyzer
-from cacheflow.sequence import SequenceGroup
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import SequenceGroup
 from cacheflow.utils import get_gpu_memory, get_cpu_memory
 from cacheflow.worker.controller import Controller, DeviceID

View File

@@ -14,32 +14,51 @@ from cacheflow.model_executor.utils import get_torch_dtype
 from cacheflow.model_executor.weight_utils import initialize_dummy_weights
 
 
-_MODELS = {
-    'gpt2': GPT2LMHeadModel,
-    'llama': LlamaForCausalLM,
-    'opt': OPTForCausalLM,
-    'stablelm': GPTNeoXForCausalLM,
-    'pythia': GPTNeoXForCausalLM,
-    'dolly-v2': GPTNeoXForCausalLM,
+# TODO(woosuk): Lazy-load the model classes.
+_MODEL_REGISTRY = {
+    "GPT2LMHeadModel": GPT2LMHeadModel,
+    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
+    "LlamaForCausalLM": LlamaForCausalLM,
+    "OPTForCausalLM": OPTForCausalLM,
 }
 
 _MEMORY_ANALYZERS = {
-    'gpt2': GPT2MemoryAnalyzer,
-    'llama': LlamaMemoryAnalyzer,
-    'opt': OPTMemoryAnalyzer,
-    'stablelm': GPTNeoXMemoryAnalyzer,
-    'pythia': GPTNeoXMemoryAnalyzer,
-    'dolly-v2': GPTNeoXMemoryAnalyzer,
+    "GPT2LMHeadModel": GPT2MemoryAnalyzer,
+    "GPTNeoXForCausalLM": GPTNeoXMemoryAnalyzer,
+    "LlamaForCausalLM": LlamaMemoryAnalyzer,
+    "OPTForCausalLM": OPTMemoryAnalyzer,
 }
 
 
+def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in _MODEL_REGISTRY:
+            return _MODEL_REGISTRY[arch]
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
+    )
+
+
+def _get_memory_analyzer(config: PretrainedConfig) -> CacheFlowMemoryAnalyzer:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in _MEMORY_ANALYZERS:
+            return _MEMORY_ANALYZERS[arch]
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {list(_MEMORY_ANALYZERS.keys())}"
+    )
+
+
 def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
-    # NOTE: getattr(config, 'torch_dtype', torch.float32) is not correct
+    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
-    config_dtype = getattr(config, 'torch_dtype', None)
+    config_dtype = getattr(config, "torch_dtype", None)
     if config_dtype is None:
         config_dtype = torch.float32
-    if dtype == 'default':
+    if dtype == "default":
         if config_dtype == torch.float32:
             # Following the common practice, we use float16 for float32 models.
             torch_dtype = torch.float16
@@ -51,7 +70,7 @@ def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
             # TODO(woosuk): Allow using float16 for bfloat16 models and
             # vice versa. Print a warning message and continue.
             raise ValueError(
-                f'Cannot use {torch_dtype} for {config_dtype} model.')
+                f"Cannot use {torch_dtype} for {config_dtype} model.")
     return torch_dtype
@@ -65,24 +84,21 @@ def get_model(
     config = AutoConfig.from_pretrained(model_name)
     torch_dtype = _get_dtype(config, dtype)
     torch.set_default_dtype(torch_dtype)
-    for model_class_name, model_class in _MODELS.items():
-        if model_class_name in model_name:
-            if use_dummy_weights:
-                # Create a model instance.
-                # The weights will be initialized as empty tensors.
-                model = model_class(config)
-                model = model.cuda()
-                # NOTE(woosuk): For precise performance evaluation, we assign
-                # random values to the weights.
-                initialize_dummy_weights(model)
-            else:
-                # Create a model instance.
-                model = model_class(config)
-                # Load the weights from the cached or downloaded files.
-                model.load_weights(model_name, cache_dir, use_np_cache)
-                model = model.cuda()
-            return model.eval(), torch_dtype
-    raise ValueError(f'Unsupported model name: {model_name}')
+    model_class = _get_model_architecture(config)
+
+    # Create a model instance.
+    # The weights will be initialized as empty tensors.
+    model = model_class(config)
+    if use_dummy_weights:
+        model = model.cuda()
+        # NOTE(woosuk): For accurate performance evaluation, we assign
+        # random values to the weights.
+        initialize_dummy_weights(model)
+    else:
+        # Load the weights from the cached or downloaded files.
+        model.load_weights(model_name, cache_dir, use_np_cache)
+        model = model.cuda()
+    return model.eval(), torch_dtype
 
 
 def get_memory_analyzer(
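For reference, a hypothetical call into the revised loader. The full get_model() signature is not visible in this hunk, so the keyword arguments and the import path below are assumptions inferred from the names used in its body.

# Hypothetical usage sketch; argument names are inferred from the hunk above
# and may not match the actual get_model() signature.
from cacheflow.model_executor import get_model  # assumed import path

model, torch_dtype = get_model(
    model_name="facebook/opt-125m",  # resolved via config.architectures, not name matching
    dtype="default",                 # float32 checkpoints are cast to float16
    cache_dir=None,
    use_np_cache=False,
    use_dummy_weights=False,
)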
@@ -95,9 +111,7 @@ def get_memory_analyzer(
 ) -> CacheFlowMemoryAnalyzer:
     config = AutoConfig.from_pretrained(model_name)
     torch_dtype = _get_dtype(config, dtype)
-    for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
-        if model_class in model_name:
-            return memory_analyzer(
-                model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
-                tensor_parallel_size)
-    raise ValueError(f'Unsupported model name: {model_name}')
+    memory_analyzer = _get_memory_analyzer(config)
+    return memory_analyzer(
+        model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
+        tensor_parallel_size)
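As a sanity check on the new resolution path, the `architectures` field consumed by `_get_model_architecture` and `_get_memory_analyzer` comes straight from the Hugging Face config. A quick probe (assuming the transformers package and Hub access) is below, with the expected values shown as comments.

from transformers import AutoConfig

for name in ("gpt2", "EleutherAI/pythia-1b", "databricks/dolly-v2-3b"):
    config = AutoConfig.from_pretrained(name)
    print(name, config.architectures)
# Expected (why the old 'pythia'/'dolly-v2'/'stablelm' keys are no longer needed):
# gpt2 ['GPT2LMHeadModel']
# EleutherAI/pythia-1b ['GPTNeoXForCausalLM']
# databricks/dolly-v2-3b ['GPTNeoXForCausalLM']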