[Minor] Fix a dtype bug (#79)

commit c84e924287 (parent c9d5b6d4a8)
Author: Woosuk Kwon
Date: 2023-05-06 02:12:12 -07:00
Committed by: GitHub


@@ -37,7 +37,11 @@ _MEMORY_ANALYZERS = {
 def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
-    config_dtype: torch.dtype = getattr(config, 'torch_dtype', torch.float32)
+    # NOTE: getattr(config, 'torch_dtype', torch.float32) is not correct
+    # because config.torch_dtype can be None.
+    config_dtype = getattr(config, 'torch_dtype', None)
+    if config_dtype is None:
+        config_dtype = torch.float32
     if dtype == 'default':
         if config_dtype == torch.float32:
             # Following the common practice, we use float16 for float32 models.
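
For context on why the removed one-liner was wrong: Python's getattr only falls
back to its default when the attribute is missing entirely, not when it exists
with the value None, and Hugging Face configs can carry torch_dtype=None. A
minimal standalone sketch of the pitfall (not part of this diff; the throwaway
config object is for illustration only):

import torch
from transformers import PretrainedConfig

config = PretrainedConfig()
config.torch_dtype = None  # HF configs may set torch_dtype to None

# Buggy pattern: the attribute exists (as None), so the default is ignored.
dtype = getattr(config, 'torch_dtype', torch.float32)
assert dtype is None

# Pattern used by this commit: treat None as "unset" and fall back explicitly.
dtype = getattr(config, 'torch_dtype', None)
if dtype is None:
    dtype = torch.float32
assert dtype == torch.float32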