Enhance model loader (#83)
parent 7c041ab578
commit add055e151
@@ -12,8 +12,8 @@ from cacheflow.core.scheduler import Scheduler
 from cacheflow.frontend.simple_frontend import SimpleFrontend
 from cacheflow.logger import init_logger
 from cacheflow.model_executor import get_memory_analyzer
-from cacheflow.sequence import SequenceGroup
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import SequenceGroup
 from cacheflow.utils import get_gpu_memory, get_cpu_memory
 from cacheflow.worker.controller import Controller, DeviceID
 
@@ -14,32 +14,51 @@ from cacheflow.model_executor.utils import get_torch_dtype
 from cacheflow.model_executor.weight_utils import initialize_dummy_weights
 
 
-_MODELS = {
-    'gpt2': GPT2LMHeadModel,
-    'llama': LlamaForCausalLM,
-    'opt': OPTForCausalLM,
-    'stablelm': GPTNeoXForCausalLM,
-    'pythia': GPTNeoXForCausalLM,
-    'dolly-v2': GPTNeoXForCausalLM,
+# TODO(woosuk): Lazy-load the model classes.
+_MODEL_REGISTRY = {
+    "GPT2LMHeadModel": GPT2LMHeadModel,
+    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
+    "LlamaForCausalLM": LlamaForCausalLM,
+    "OPTForCausalLM": OPTForCausalLM,
 }
 
 _MEMORY_ANALYZERS = {
-    'gpt2': GPT2MemoryAnalyzer,
-    'llama': LlamaMemoryAnalyzer,
-    'opt': OPTMemoryAnalyzer,
-    'stablelm': GPTNeoXMemoryAnalyzer,
-    'pythia': GPTNeoXMemoryAnalyzer,
-    'dolly-v2': GPTNeoXMemoryAnalyzer,
+    "GPT2LMHeadModel": GPT2MemoryAnalyzer,
+    "GPTNeoXForCausalLM": GPTNeoXMemoryAnalyzer,
+    "LlamaForCausalLM": LlamaMemoryAnalyzer,
+    "OPTForCausalLM": OPTMemoryAnalyzer,
 }
 
 
+def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in _MODEL_REGISTRY:
+            return _MODEL_REGISTRY[arch]
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
+    )
+
+
+def _get_memory_analyzer(config: PretrainedConfig) -> CacheFlowMemoryAnalyzer:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in _MEMORY_ANALYZERS:
+            return _MEMORY_ANALYZERS[arch]
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {list(_MEMORY_ANALYZERS.keys())}"
+    )
+
+
 def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
-    # NOTE: getattr(config, 'torch_dtype', torch.float32) is not correct
+    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
-    config_dtype = getattr(config, 'torch_dtype', None)
+    config_dtype = getattr(config, "torch_dtype", None)
     if config_dtype is None:
         config_dtype = torch.float32
-    if dtype == 'default':
+    if dtype == "default":
         if config_dtype == torch.float32:
             # Following the common practice, we use float16 for float32 models.
             torch_dtype = torch.float16
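
Note: the registry is now keyed by the architecture strings that Hugging Face records in each checkpoint's config, rather than by substrings of the model name. A minimal sketch of how the lookup resolves, assuming an illustrative checkpoint name ("facebook/opt-125m" is only an example):

# Sketch only: "facebook/opt-125m" is an illustrative checkpoint name.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("facebook/opt-125m")
print(config.architectures)                    # e.g. ["OPTForCausalLM"]
model_class = _get_model_architecture(config)  # resolves to OPTForCausalLM via _MODEL_REGISTRY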
@@ -51,7 +70,7 @@ def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
             # TODO(woosuk): Allow using float16 for bfloat16 models and
             # vice versa. Print a warning message and continue.
             raise ValueError(
-                f'Cannot use {torch_dtype} for {config_dtype} model.')
+                f"Cannot use {torch_dtype} for {config_dtype} model.")
     return torch_dtype
 
 
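The dtype rule above boils down to: use the config's torch_dtype when set, otherwise assume float32, and when the caller asks for "default", downcast float32 models to float16. A standalone sketch of that rule (resolve_default_dtype is a hypothetical name, not the module's actual helper, which also handles explicit dtype strings via get_torch_dtype):

import torch

# Standalone illustration of the "default" dtype rule; not part of the codebase.
def resolve_default_dtype(config_dtype) -> torch.dtype:
    if config_dtype is None:
        config_dtype = torch.float32   # config.torch_dtype may be unset (None)
    if config_dtype == torch.float32:
        return torch.float16           # common practice: fp16 for fp32 checkpoints
    return config_dtype                # keep bf16/fp16 checkpoints as-is

assert resolve_default_dtype(None) == torch.float16
assert resolve_default_dtype(torch.bfloat16) == torch.bfloat16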
@@ -65,24 +84,21 @@ def get_model(
     config = AutoConfig.from_pretrained(model_name)
     torch_dtype = _get_dtype(config, dtype)
     torch.set_default_dtype(torch_dtype)
-    for model_class_name, model_class in _MODELS.items():
-        if model_class_name in model_name:
-            if use_dummy_weights:
-                # Create a model instance.
-                # The weights will be initialized as empty tensors.
-                model = model_class(config)
-                model = model.cuda()
-                # NOTE(woosuk): For precise performance evaluation, we assign
-                # random values to the weights.
-                initialize_dummy_weights(model)
-            else:
-                # Create a model instance.
-                model = model_class(config)
-                # Load the weights from the cached or downloaded files.
-                model.load_weights(model_name, cache_dir, use_np_cache)
-                model = model.cuda()
-            return model.eval(), torch_dtype
-
-    raise ValueError(f'Unsupported model name: {model_name}')
+    model_class = _get_model_architecture(config)
+
+    # Create a model instance.
+    # The weights will be initialized as empty tensors.
+    model = model_class(config)
+    if use_dummy_weights:
+        model = model.cuda()
+        # NOTE(woosuk): For accurate performance evaluation, we assign
+        # random values to the weights.
+        initialize_dummy_weights(model)
+    else:
+        # Load the weights from the cached or downloaded files.
+        model.load_weights(model_name, cache_dir, use_np_cache)
+        model = model.cuda()
+    return model.eval(), torch_dtype
 
 
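A hedged usage sketch of the refactored loader; the keyword names follow the parameters referenced in the hunk body and the checkpoint name and values are illustrative, not from this commit:

# Illustrative call; "facebook/opt-125m" and the argument values are examples only.
model, torch_dtype = get_model(
    model_name="facebook/opt-125m",
    dtype="default",            # resolved to fp16 for an fp32 checkpoint
    cache_dir=None,
    use_dummy_weights=True,     # random weights on GPU, no checkpoint download needed
    use_np_cache=False,
)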
@@ -95,9 +111,7 @@ def get_memory_analyzer(
 ) -> CacheFlowMemoryAnalyzer:
     config = AutoConfig.from_pretrained(model_name)
     torch_dtype = _get_dtype(config, dtype)
-    for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
-        if model_class in model_name:
-            return memory_analyzer(
-                model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
-                tensor_parallel_size)
-
-    raise ValueError(f'Unsupported model name: {model_name}')
+    memory_analyzer = _get_memory_analyzer(config)
+    return memory_analyzer(
+        model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
+        tensor_parallel_size)
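The memory analyzer is selected the same way, so checkpoints whose names do not contain a known substring now work as long as their config.architectures entry is registered. A sketch with made-up values (parameter names mirror those referenced in the hunk):

# Illustrative values only.
memory_analyzer = get_memory_analyzer(
    model_name="facebook/opt-125m",
    block_size=16,
    dtype="default",
    gpu_memory=get_gpu_memory(),
    cpu_memory=get_cpu_memory(),
    tensor_parallel_size=1,
)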