[TPU] Fix dummy loading OOM (#16372)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
commit 1621b25288
parent a564797151
@@ -658,8 +658,21 @@ def initialize_dummy_weights(
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
             if current_platform.is_tpu():
-                # XLA device does not support torch.Generator()
-                param.uniform_(low, high)
+                generator = torch.Generator(device="cpu")
+                generator.manual_seed(seed)
+                # Note: The param.uniform_ function cannot be used in this
+                # context because it demands more TPU HBM than directly copying
+                # from a CPU tensor.
+                # Note: We avoid using torch.rand_like as it doesn't currently
+                # support the generator argument.
+                param.copy_((high - low) *
+                            torch.rand(*param.shape,
+                                       generator=generator,
+                                       dtype=param.dtype,
+                                       layout=param.layout,
+                                       requires_grad=param.requires_grad,
+                                       device="cpu") + low)
+                torch._sync(param)
                 continue
 
             generator = torch.Generator(device=param.data.device)
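For context, the pattern this diff introduces can be reduced to a standalone sketch: draw the seeded random values on the CPU, then copy them in place into the XLA parameter instead of running the RNG op on the device. The following is a minimal sketch, not vLLM's actual loader; it assumes torch_xla is installed, and the nn.Linear model and the low/high/seed values are illustrative stand-ins for the parameters of initialize_dummy_weights.

# Minimal sketch of the CPU-rand-then-copy pattern from this commit.
# Assumptions: torch_xla is installed and an XLA device is available;
# the model and the low/high/seed values below are illustrative only.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

low, high, seed = -1e-3, 1e-3, 1234
model = nn.Linear(4096, 4096).to(xm.xla_device())

generator = torch.Generator(device="cpu")
generator.manual_seed(seed)

for param in model.state_dict().values():
    if torch.is_floating_point(param):
        # Draw the uniform values on the CPU; calling param.uniform_ on
        # the XLA tensor would allocate extra TPU HBM for the RNG op.
        cpu_vals = (high - low) * torch.rand(*param.shape,
                                             generator=generator,
                                             dtype=param.dtype,
                                             device="cpu") + low
        # In-place copy into the device parameter, then force
        # materialization so the CPU staging buffer can be released.
        param.copy_(cpu_vals)
        torch._sync(param)

The CPU generator keeps the dummy initialization deterministic for a given seed; torch.rand_like is avoided because it does not accept a generator argument, which is why the diff spells out torch.rand with explicit dtype and layout.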