vllm/cacheflow/model_executor/weight_utils.py
2023-05-09 15:30:12 -07:00

111 lines
3.8 KiB
Python

import filelock
import glob
import json
import os
from typing import Iterator, List, Optional, Tuple
from huggingface_hub import snapshot_download
import numpy as np
import torch
from tqdm.auto import tqdm
class Disabledtqdm(tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, disable=True)
def hf_model_weights_iterator(
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False,
) -> Iterator[Tuple[str, torch.Tensor]]:
# Prepare file lock directory to prevent multiple processes from
# downloading the same model weights at the same time.
lock_dir = cache_dir if cache_dir is not None else "/tmp"
lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
# Download model weights from huggingface.
is_local = os.path.isdir(model_name_or_path)
if not is_local:
with lock:
hf_folder = snapshot_download(model_name_or_path,
allow_patterns="*.bin",
cache_dir=cache_dir,
tqdm_class=Disabledtqdm)
else:
hf_folder = model_name_or_path
hf_bin_files = glob.glob(os.path.join(hf_folder, "*.bin"))
if use_np_cache:
# Convert the model weights from torch tensors to numpy arrays for
# faster loading.
np_folder = os.path.join(hf_folder, 'np')
os.makedirs(np_folder, exist_ok=True)
weight_names_file = os.path.join(np_folder, 'weight_names.json')
with lock:
if not os.path.exists(weight_names_file):
weight_names = []
for bin_file in hf_bin_files:
state = torch.load(bin_file, map_location="cpu")
for name, param in state.items():
param_path = os.path.join(np_folder, name)
with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy())
weight_names.append(name)
with open(weight_names_file, 'w') as f:
json.dump(weight_names, f)
with open(weight_names_file, 'r') as f:
weight_names = json.load(f)
for name in weight_names:
param_path = os.path.join(np_folder, name)
with open(param_path, "rb") as f:
param = np.load(f)
yield name, torch.from_numpy(param)
else:
for bin_file in hf_bin_files:
state = torch.load(bin_file, map_location="cpu")
for name, param in state.items():
yield name, param
def load_tensor_parallel_weights(
param: torch.Tensor,
loaded_weight: torch.Tensor,
param_name: str,
column_parallel_weight_names: List[str],
row_parallel_weight_names: List[str],
tensor_model_parallel_rank: int,
) -> None:
for p in column_parallel_weight_names:
if p in param_name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
for p in row_parallel_weight_names:
if p in param_name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
assert param.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
def initialize_dummy_weights(
model: torch.nn.Module,
low: float = -1e-3,
high: float = 1e-3,
) -> None:
for param in model.state_dict().values():
param.data.uniform_(low, high)