[Minor] Fix a dtype bug (#79)
commit c84e924287
parent c9d5b6d4a8
@@ -37,7 +37,11 @@ _MEMORY_ANALYZERS = {

 def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
-    config_dtype: torch.dtype = getattr(config, 'torch_dtype', torch.float32)
+    # NOTE: getattr(config, 'torch_dtype', torch.float32) is not correct
+    # because config.torch_dtype can be None.
+    config_dtype = getattr(config, 'torch_dtype', None)
+    if config_dtype is None:
+        config_dtype = torch.float32
     if dtype == 'default':
         if config_dtype == torch.float32:
             # Following the common practice, we use float16 for float32 models.
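For context, the bug is a common Python gotcha: getattr's default is returned only when the attribute is absent, not when it exists but is set to None. A minimal sketch of the behavior this commit fixes; the SimpleNamespace config is a hypothetical stand-in for a transformers PretrainedConfig whose torch_dtype was never populated:

import types

import torch

# A config whose torch_dtype attribute exists but is None,
# mimicking a PretrainedConfig with no dtype recorded.
config = types.SimpleNamespace(torch_dtype=None)

# The old code: getattr's default only applies when the attribute
# is missing, so this returns None instead of torch.float32.
buggy = getattr(config, 'torch_dtype', torch.float32)
assert buggy is None

# The fixed pattern: fall back to float32 explicitly when the
# attribute is missing *or* present but None.
config_dtype = getattr(config, 'torch_dtype', None)
if config_dtype is None:
    config_dtype = torch.float32
assert config_dtype == torch.float32

Note that the fix also drops the torch.dtype annotation on the assignment, since getattr(config, 'torch_dtype', None) can legitimately return None before the fallback is applied.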