Enhance model loader (#83)
This commit is contained in:
parent
7c041ab578
commit
add055e151
@ -12,8 +12,8 @@ from cacheflow.core.scheduler import Scheduler
|
||||
from cacheflow.frontend.simple_frontend import SimpleFrontend
|
||||
from cacheflow.logger import init_logger
|
||||
from cacheflow.model_executor import get_memory_analyzer
|
||||
from cacheflow.sequence import SequenceGroup
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
from cacheflow.sequence import SequenceGroup
|
||||
from cacheflow.utils import get_gpu_memory, get_cpu_memory
|
||||
from cacheflow.worker.controller import Controller, DeviceID
|
||||
|
||||
|
@ -14,32 +14,51 @@ from cacheflow.model_executor.utils import get_torch_dtype
|
||||
from cacheflow.model_executor.weight_utils import initialize_dummy_weights
|
||||
|
||||
|
||||
_MODELS = {
|
||||
'gpt2': GPT2LMHeadModel,
|
||||
'llama': LlamaForCausalLM,
|
||||
'opt': OPTForCausalLM,
|
||||
'stablelm': GPTNeoXForCausalLM,
|
||||
'pythia': GPTNeoXForCausalLM,
|
||||
'dolly-v2': GPTNeoXForCausalLM,
|
||||
# TODO(woosuk): Lazy-load the model classes.
|
||||
_MODEL_REGISTRY = {
|
||||
"GPT2LMHeadModel": GPT2LMHeadModel,
|
||||
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
|
||||
"LlamaForCausalLM": LlamaForCausalLM,
|
||||
"OPTForCausalLM": OPTForCausalLM,
|
||||
}
|
||||
|
||||
_MEMORY_ANALYZERS = {
|
||||
'gpt2': GPT2MemoryAnalyzer,
|
||||
'llama': LlamaMemoryAnalyzer,
|
||||
'opt': OPTMemoryAnalyzer,
|
||||
'stablelm': GPTNeoXMemoryAnalyzer,
|
||||
'pythia': GPTNeoXMemoryAnalyzer,
|
||||
'dolly-v2': GPTNeoXMemoryAnalyzer,
|
||||
"GPT2LMHeadModel": GPT2MemoryAnalyzer,
|
||||
"GPTNeoXForCausalLM": GPTNeoXMemoryAnalyzer,
|
||||
"LlamaForCausalLM": LlamaMemoryAnalyzer,
|
||||
"OPTForCausalLM": OPTMemoryAnalyzer,
|
||||
}
|
||||
|
||||
|
||||
def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
if arch in _MODEL_REGISTRY:
|
||||
return _MODEL_REGISTRY[arch]
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
|
||||
)
|
||||
|
||||
|
||||
def _get_memory_analyzer(config: PretrainedConfig) -> CacheFlowMemoryAnalyzer:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
if arch in _MEMORY_ANALYZERS:
|
||||
return _MEMORY_ANALYZERS[arch]
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {list(_MEMORY_ANALYZERS.keys())}"
|
||||
)
|
||||
|
||||
|
||||
def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
|
||||
# NOTE: getattr(config, 'torch_dtype', torch.float32) is not correct
|
||||
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
|
||||
# because config.torch_dtype can be None.
|
||||
config_dtype = getattr(config, 'torch_dtype', None)
|
||||
config_dtype = getattr(config, "torch_dtype", None)
|
||||
if config_dtype is None:
|
||||
config_dtype = torch.float32
|
||||
if dtype == 'default':
|
||||
if dtype == "default":
|
||||
if config_dtype == torch.float32:
|
||||
# Following the common practice, we use float16 for float32 models.
|
||||
torch_dtype = torch.float16
|
||||
@ -51,7 +70,7 @@ def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
|
||||
# TODO(woosuk): Allow using float16 for bfloat16 models and
|
||||
# vice versa. Print a warning message and continue.
|
||||
raise ValueError(
|
||||
f'Cannot use {torch_dtype} for {config_dtype} model.')
|
||||
f"Cannot use {torch_dtype} for {config_dtype} model.")
|
||||
return torch_dtype
|
||||
|
||||
|
||||
@ -65,24 +84,21 @@ def get_model(
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
torch_dtype = _get_dtype(config, dtype)
|
||||
torch.set_default_dtype(torch_dtype)
|
||||
for model_class_name, model_class in _MODELS.items():
|
||||
if model_class_name in model_name:
|
||||
if use_dummy_weights:
|
||||
# Create a model instance.
|
||||
# The weights will be initialized as empty tensors.
|
||||
model = model_class(config)
|
||||
model = model.cuda()
|
||||
# NOTE(woosuk): For precise performance evaluation, we assign
|
||||
# random values to the weights.
|
||||
initialize_dummy_weights(model)
|
||||
else:
|
||||
# Create a model instance.
|
||||
model = model_class(config)
|
||||
# Load the weights from the cached or downloaded files.
|
||||
model.load_weights(model_name, cache_dir, use_np_cache)
|
||||
model = model.cuda()
|
||||
return model.eval(), torch_dtype
|
||||
raise ValueError(f'Unsupported model name: {model_name}')
|
||||
model_class = _get_model_architecture(config)
|
||||
|
||||
# Create a model instance.
|
||||
# The weights will be initialized as empty tensors.
|
||||
model = model_class(config)
|
||||
if use_dummy_weights:
|
||||
model = model.cuda()
|
||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||
# random values to the weights.
|
||||
initialize_dummy_weights(model)
|
||||
else:
|
||||
# Load the weights from the cached or downloaded files.
|
||||
model.load_weights(model_name, cache_dir, use_np_cache)
|
||||
model = model.cuda()
|
||||
return model.eval(), torch_dtype
|
||||
|
||||
|
||||
def get_memory_analyzer(
|
||||
@ -95,9 +111,7 @@ def get_memory_analyzer(
|
||||
) -> CacheFlowMemoryAnalyzer:
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
torch_dtype = _get_dtype(config, dtype)
|
||||
for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
|
||||
if model_class in model_name:
|
||||
return memory_analyzer(
|
||||
model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
|
||||
tensor_parallel_size)
|
||||
raise ValueError(f'Unsupported model name: {model_name}')
|
||||
memory_analyzer = _get_memory_analyzer(config)
|
||||
return memory_analyzer(
|
||||
model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
|
||||
tensor_parallel_size)
|
||||
|
Loading…
x
Reference in New Issue
Block a user