[TPU] Fix dummy loading OOM (#16372)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
Author: Chengji Yao, 2025-04-09 21:06:16 -07:00, committed by GitHub
parent a564797151
commit 1621b25288

@@ -658,8 +658,21 @@ def initialize_dummy_weights(
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
             if current_platform.is_tpu():
-                # XLA device does not support torch.Generator()
-                param.uniform_(low, high)
+                generator = torch.Generator(device="cpu")
+                generator.manual_seed(seed)
+                # Note: The param.uniform_ function cannot be used in this
+                # context because it demands more TPU HBM than directly
+                # copying from a CPU tensor.
+                # Note: We avoid using torch.rand_like as it doesn't
+                # currently support the generator argument.
+                param.copy_((high - low) *
+                            torch.rand(*param.shape,
+                                       generator=generator,
+                                       dtype=param.dtype,
+                                       layout=param.layout,
+                                       requires_grad=param.requires_grad,
+                                       device="cpu") + low)
+                torch._sync(param)
                 continue
             generator = torch.Generator(device=param.data.device)
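
For context, the snippet below is a minimal, self-contained sketch of the pattern this change adopts: seed a CPU generator, materialize the random values on the CPU, then copy the finished tensor onto the device. The helper name init_dummy_param and its standalone structure are illustrative, not vLLM's actual API; only the CPU-generate-then-copy idea comes from the commit above.

import torch

def init_dummy_param(param: torch.Tensor,
                     low: float = -1e-3,
                     high: float = 1e-3,
                     seed: int = 1234) -> None:
    # A seeded CPU generator keeps dummy weights reproducible; XLA devices
    # cannot host a torch.Generator directly.
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    # Build uniform values in [low, high) on the CPU first. Running
    # uniform_ in place on the device tensor would allocate extra HBM
    # for the op, which is what triggered the OOM this commit fixes.
    cpu_values = (high - low) * torch.rand(*param.shape,
                                           generator=generator,
                                           dtype=param.dtype,
                                           device="cpu") + low
    # A single copy_ transfers the finished values to the target device.
    param.data.copy_(cpu_values)

In the commit itself, the copy on TPU is followed by torch._sync(param) so the transfer is flushed before the next parameter is processed.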