[Bugfix] Fix dummy weight for fp8 (#4916)

Allow the dummy load format for fp8;
`Tensor.uniform_` doesn't support FP8 at the moment, so initialize in float16 and cast back.
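
For context, a minimal repro of the failure this works around (a sketch assuming a PyTorch build that ships the FP8 dtypes; the exact error text may vary by version):

    import torch

    t = torch.empty(4, dtype=torch.float8_e4m3fn)
    t.uniform_(-1e-3, 1e-3)  # expected to raise RuntimeError: uniform_ has no FP8 kernel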

Co-authored-by: Mor Zusman <morz@ai21.com>
Mor Zusman 2024-05-20 21:44:25 +03:00 committed by GitHub
parent 943e72ca56
commit f0eecee610


@@ -369,4 +369,11 @@ def initialize_dummy_weights(
     """
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
-            param.data.uniform_(low, high)
+            if torch.finfo(param.data.dtype).bits < 16:
+                # uniform_ doesn't support < 16-bit datatypes (FP8)
+                dtype = param.data.dtype
+                tmp_param = param.data.to(torch.float16)
+                tmp_param = tmp_param.uniform_(low, high).to(dtype)
+                param.data.copy_(tmp_param)
+            else:
+                param.uniform_(low, high)
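
The same cast-through-float16 workaround as a standalone sketch (the helper name init_fp8_dummy is hypothetical, not part of vLLM):

    import torch

    def init_fp8_dummy(param: torch.Tensor,
                       low: float = -1e-3,
                       high: float = 1e-3) -> None:
        # FP8 dtypes have fewer than 16 bits and no uniform_ kernel,
        # so sample in float16 and round-trip back to the FP8 dtype.
        if torch.finfo(param.dtype).bits < 16:
            tmp = param.to(torch.float16).uniform_(low, high).to(param.dtype)
            param.copy_(tmp)
        else:
            param.uniform_(low, high)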