2023-05-09 15:30:12 -07:00
|
|
|
import filelock
|
2023-05-03 15:32:04 +08:00
|
|
|
import glob
|
|
|
|
import json
|
2023-05-09 15:30:12 -07:00
|
|
|
import os
|
|
|
|
from typing import Iterator, List, Optional, Tuple
|
2023-03-11 23:23:14 -08:00
|
|
|
|
2023-05-09 15:30:12 -07:00
|
|
|
from huggingface_hub import snapshot_download
|
2023-05-03 15:32:04 +08:00
|
|
|
import numpy as np
|
2023-03-11 23:23:14 -08:00
|
|
|
import torch
|
2023-05-03 15:32:04 +08:00
|
|
|
from tqdm.auto import tqdm
|
2023-03-11 23:23:14 -08:00
|
|
|
|
2023-05-03 15:32:04 +08:00
|
|
|
|
|
|
|
class Disabledtqdm(tqdm):
    """A ``tqdm`` subclass whose progress bar is always disabled.

    It is handed to ``snapshot_download`` as ``tqdm_class`` so that model
    downloads do not print progress bars.
    """

    def __init__(self, *args, **kwargs):
        # Force the flag through the kwargs dict rather than passing
        # ``disable=True`` next to ``**kwargs``: the latter raises
        # ``TypeError`` (duplicate keyword argument) whenever the caller
        # already supplies ``disable``.
        kwargs["disable"] = True
        super().__init__(*args, **kwargs)
|
|
|
|
|
|
|
|
|
2023-05-09 15:30:12 -07:00
|
|
|
def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_np_cache: bool = False,
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Iterate over the weights stored in a model's ``*.bin`` checkpoints.

    If ``model_name_or_path`` is an existing directory, its ``*.bin`` files
    are read directly. Otherwise it is passed to ``snapshot_download`` and
    only the ``*.bin`` files are fetched. A file lock keeps several
    processes from downloading (or converting) the same model at once.

    Args:
        model_name_or_path: A local directory containing ``*.bin`` files, or
            an identifier understood by ``snapshot_download``.
        cache_dir: Forwarded to ``snapshot_download``. It is also the
            directory that holds the lock file; ``"/tmp"`` is used for the
            lock when this is ``None``.
        use_np_cache: If true, every tensor is converted once to a NumPy
            file under ``<model folder>/np`` (with the list of names stored
            in ``weight_names.json``) and the weights are then read back
            from those files instead of the ``*.bin`` checkpoints.

    Yields:
        ``(name, tensor)`` for every entry of every checkpoint state dict.
        Tensors are loaded onto the CPU.
    """
    # Prepare file lock directory to prevent multiple processes from
    # downloading the same model weights at the same time.
    lock_dir = cache_dir if cache_dir is not None else "/tmp"
    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))

    is_local = os.path.isdir(model_name_or_path)
    if use_np_cache or not is_local:
        # These are the only paths that acquire the lock. A user-supplied
        # cache_dir may not have been created yet, so make sure the
        # directory that will hold the lock file exists first.
        os.makedirs(lock_dir, exist_ok=True)

    # Download model weights from huggingface.
    if not is_local:
        with lock:
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns="*.bin",
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm)
    else:
        hf_folder = model_name_or_path

    # glob order depends on the filesystem; sort it so the weights are
    # always visited in the same order.
    hf_bin_files = sorted(glob.glob(os.path.join(hf_folder, "*.bin")))

    if use_np_cache:
        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading.
        np_folder = os.path.join(hf_folder, 'np')
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, 'weight_names.json')
        with lock:
            # weight_names.json is written last, so its presence marks a
            # completed conversion.
            if not os.path.exists(weight_names_file):
                weight_names = []
                for bin_file in hf_bin_files:
                    state = torch.load(bin_file, map_location="cpu")
                    for name, param in state.items():
                        param_path = os.path.join(np_folder, name)
                        with open(param_path, "wb") as f:
                            np.save(f, param.cpu().detach().numpy())
                        weight_names.append(name)
                    # Drop this checkpoint before loading the next one so
                    # two state dicts are never held in memory together.
                    del state
                with open(weight_names_file, 'w') as f:
                    json.dump(weight_names, f)

        with open(weight_names_file, 'r') as f:
            weight_names = json.load(f)

        for name in weight_names:
            param_path = os.path.join(np_folder, name)
            with open(param_path, "rb") as f:
                param = np.load(f)
            yield name, torch.from_numpy(param)
    else:
        for bin_file in hf_bin_files:
            state = torch.load(bin_file, map_location="cpu")
            for name, param in state.items():
                yield name, param
            # Drop this checkpoint before loading the next one so two
            # state dicts are never held in memory together.
            del state
|
|
|
|
|
|
|
|
|
2023-05-09 15:30:12 -07:00
|
|
|
def load_tensor_parallel_weights(
    param: torch.Tensor,
    loaded_weight: torch.Tensor,
    param_name: str,
    column_parallel_weight_names: List[str],
    row_parallel_weight_names: List[str],
    tensor_model_parallel_rank: int,
) -> None:
    """Copy this rank's shard of ``loaded_weight`` into ``param`` in place.

    Args:
        param: Destination tensor; its shape defines the shard size.
        loaded_weight: The weight as loaded from the checkpoint.
        param_name: Name of the parameter being loaded.
        column_parallel_weight_names: If any entry is a substring of
            ``param_name``, ``loaded_weight`` is sliced along dim 0 in
            chunks of ``param.shape[0]``.
        row_parallel_weight_names: If any entry is a substring of
            ``param_name``, ``loaded_weight`` is sliced along dim 1 in
            chunks of ``param.shape[1]``.
        tensor_model_parallel_rank: Index of the chunk to take.

    Raises:
        AssertionError: If the (possibly sliced) weight does not have the
            same shape as ``param``.
    """
    rank = tensor_model_parallel_rank

    if any(key in param_name for key in column_parallel_weight_names):
        shard_size = param.shape[0]
        loaded_weight = loaded_weight[shard_size * rank:
                                      shard_size * (rank + 1)]

    # Checked independently of the column case: a name that matches both
    # lists is sliced along both dimensions.
    if any(key in param_name for key in row_parallel_weight_names):
        shard_size = param.shape[1]
        loaded_weight = loaded_weight[:, shard_size * rank:
                                      shard_size * (rank + 1)]

    assert param.shape == loaded_weight.shape, (
        f"Shape mismatch for {param_name}: parameter has shape "
        f"{tuple(param.shape)} but the loaded weight has shape "
        f"{tuple(loaded_weight.shape)} after sharding.")
    param.data.copy_(loaded_weight)
|
2023-05-09 15:30:12 -07:00
|
|
|
|
|
|
|
|
|
|
|
def initialize_dummy_weights(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
) -> None:
    """Fill the model's floating-point tensors with uniform random values.

    Every floating-point tensor in ``model.state_dict()`` is overwritten in
    place with samples drawn from the uniform distribution over
    ``[low, high)``.

    Args:
        model: Module whose state-dict tensors are overwritten.
        low: Lower bound of the uniform distribution.
        high: Upper bound of the uniform distribution.
    """
    for tensor in model.state_dict().values():
        # The state dict also contains buffers, which may be integer or
        # boolean (e.g. BatchNorm's ``num_batches_tracked``); ``uniform_``
        # is not defined for those dtypes, so leave them untouched.
        if torch.is_floating_point(tensor):
            tensor.data.uniform_(low, high)
|