[Bugfix] Fix dummy weight for fp8 (#4916)
Allow the dummy load format for fp8; torch.Tensor.uniform_ doesn't support FP8 at the moment.
Co-authored-by: Mor Zusman <morz@ai21.com>
This commit is contained in:
parent 943e72ca56
commit f0eecee610
@@ -369,4 +369,11 @@ def initialize_dummy_weights(
     """
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
-            param.data.uniform_(low, high)
+            if torch.finfo(param.data.dtype).bits < 16:
+                # uniform_ doesn't support < 16-bit datatypes (FP8)
+                dtype = param.data.dtype
+                tmp_param = param.data.to(torch.float16)
+                tmp_param = tmp_param.uniform_(low, high).to(dtype)
+                param.data.copy_(tmp_param)
+            else:
+                param.uniform_(low, high)
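For illustration, a minimal standalone sketch of the workaround this commit applies, assuming a PyTorch build where torch.float8_e4m3fn is available (the tensor shape and the low/high bounds below are made up for the example). Since Tensor.uniform_() rejects sub-16-bit dtypes, the random values are drawn in float16 and then cast back down to FP8:

import torch

low, high = -1e-3, 1e-3
# Hypothetical FP8 parameter; real code would iterate over model.state_dict().
param = torch.empty(4, 4, dtype=torch.float8_e4m3fn)

if torch.finfo(param.dtype).bits < 16:
    tmp = param.to(torch.float16)      # upcast to a dtype uniform_ supports
    tmp.uniform_(low, high)            # fill with random values in [low, high)
    param.copy_(tmp.to(param.dtype))   # cast back to FP8 and copy into place
else:
    param.uniform_(low, high)          # >= 16-bit dtypes work directly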