vllm/vllm/model_executor/weight_utils.py

119 lines
4.2 KiB
Python
Raw Normal View History

"""Utilities for downloading and initializing model weights."""
2023-05-09 15:30:12 -07:00
import filelock
import glob
import json
2023-05-09 15:30:12 -07:00
import os
from typing import Iterator, List, Optional, Tuple
2023-05-09 15:30:12 -07:00
from huggingface_hub import snapshot_download
import numpy as np
import torch
from tqdm.auto import tqdm
class Disabledtqdm(tqdm):
    """A ``tqdm`` progress bar that is always disabled.

    Passed as ``tqdm_class`` to ``snapshot_download`` so that downloading
    model weights does not print progress bars.
    """

    def __init__(self, *args, **kwargs):
        # Force-disable regardless of what the caller passed. Overriding the
        # key (instead of passing ``disable=True`` next to ``**kwargs``)
        # avoids "got multiple values for keyword argument 'disable'" when
        # the caller already supplies ``disable`` itself.
        kwargs["disable"] = True
        super().__init__(*args, **kwargs)
2023-05-09 15:30:12 -07:00
def _get_weights_lock(model_name_or_path: str, cache_dir: Optional[str]):
    """Return a file lock serializing download/conversion of one model."""
    # Prepare file lock directory to prevent multiple processes from
    # downloading (or converting) the same model weights at the same time.
    lock_dir = cache_dir if cache_dir is not None else "/tmp"
    # The lock file cannot be created inside a directory that does not exist
    # yet (e.g. a cache_dir that has never been used).
    os.makedirs(lock_dir, exist_ok=True)
    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
    return filelock.FileLock(os.path.join(lock_dir, lock_file_name))


def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_np_cache: bool = False,
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Yield ``(name, tensor)`` pairs for every weight in a model checkpoint.

    Args:
        model_name_or_path: A local directory containing ``*.bin`` checkpoint
            files, or a HuggingFace Hub model id whose ``*.bin`` files are
            downloaded with ``snapshot_download``.
        cache_dir: Passed to ``snapshot_download``; also the directory that
            holds the lock file (``/tmp`` when ``None``).
        use_np_cache: If True, every checkpoint tensor is first saved as a
            numpy file under ``<model folder>/np`` (done once, guarded by the
            lock) and tensors are then loaded from those files.

    Yields:
        Parameter name and CPU tensor. Checkpoint files are visited in sorted
        path order; within a file, in the order stored in its state dict.
    """
    is_local = os.path.isdir(model_name_or_path)
    # The lock is only needed when downloading or converting to numpy.
    lock = (_get_weights_lock(model_name_or_path, cache_dir)
            if not is_local or use_np_cache else None)

    # Download model weights from huggingface unless a local path was given.
    if not is_local:
        with lock:
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns="*.bin",
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm)
    else:
        hf_folder = model_name_or_path
    # glob() returns files in arbitrary order; sort so the iteration order is
    # reproducible across runs and machines.
    hf_bin_files = sorted(glob.glob(os.path.join(hf_folder, "*.bin")))

    # NOTE(security): torch.load unpickles the checkpoint files, which can
    # execute arbitrary code; only load weights from trusted sources.
    if use_np_cache:
        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")
        with lock:
            # weight_names.json is written only after every tensor has been
            # saved, so its presence marks a finished conversion.
            if not os.path.exists(weight_names_file):
                weight_names = []
                for bin_file in hf_bin_files:
                    state = torch.load(bin_file, map_location="cpu")
                    for name, param in state.items():
                        param_path = os.path.join(np_folder, name)
                        with open(param_path, "wb") as f:
                            np.save(f, param.cpu().detach().numpy())
                        weight_names.append(name)
                    # Release this shard before loading the next one.
                    del state
                with open(weight_names_file, "w") as f:
                    json.dump(weight_names, f)
        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)
        for name in weight_names:
            param_path = os.path.join(np_folder, name)
            with open(param_path, "rb") as f:
                param = np.load(f)
            yield name, torch.from_numpy(param)
    else:
        for bin_file in hf_bin_files:
            state = torch.load(bin_file, map_location="cpu")
            for name, param in state.items():
                yield name, param
            # Drop this shard's state dict before the next torch.load so at
            # most one shard is held by this generator at a time.
            del state
2023-05-09 15:30:12 -07:00
def load_tensor_parallel_weights(
    param: torch.Tensor,
    loaded_weight: torch.Tensor,
    param_name: str,
    column_parallel_weight_names: List[str],
    row_parallel_weight_names: List[str],
    tensor_model_parallel_rank: int,
) -> None:
    """Copy this rank's shard of a checkpoint tensor into ``param`` in place.

    If any entry of ``column_parallel_weight_names`` is a substring of
    ``param_name``, ``loaded_weight`` is cut along dim 0 into chunks of
    ``param.shape[0]`` and chunk number ``tensor_model_parallel_rank`` is
    kept. If any entry of ``row_parallel_weight_names`` matches, the same is
    done along dim 1 using ``param.shape[1]``. The two checks are applied
    independently; with no match the tensor is copied whole.

    Args:
        param: Destination parameter (already shard-sized).
        loaded_weight: Full tensor read from the checkpoint.
        param_name: Name used for substring matching against the lists.
        column_parallel_weight_names: Name fragments sharded along dim 0.
        row_parallel_weight_names: Name fragments sharded along dim 1.
        tensor_model_parallel_rank: Index of the shard to take.

    Raises:
        AssertionError: If the (sliced) tensor's shape differs from
            ``param.shape``.
    """
    for pattern in column_parallel_weight_names:
        if pattern in param_name:
            shard_size = param.shape[0]
            start = shard_size * tensor_model_parallel_rank
            loaded_weight = loaded_weight[start:start + shard_size]
            break
    for pattern in row_parallel_weight_names:
        if pattern in param_name:
            shard_size = param.shape[1]
            start = shard_size * tensor_model_parallel_rank
            loaded_weight = loaded_weight[:, start:start + shard_size]
            break
    # Explicit check instead of a bare ``assert``: it survives ``python -O``
    # (otherwise ``copy_`` could silently broadcast e.g. a (1, n) tensor into
    # an (m, n) parameter) and reports which weight is wrong. The exception
    # type is kept as AssertionError for compatibility.
    if param.shape != loaded_weight.shape:
        raise AssertionError(
            f"Shape mismatch for {param_name!r}: parameter has shape "
            f"{tuple(param.shape)} but the checkpoint shard has shape "
            f"{tuple(loaded_weight.shape)}")
    param.data.copy_(loaded_weight)
2023-05-09 15:30:12 -07:00
def initialize_dummy_weights(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
) -> None:
    """Initialize model weights with random values.

    The model weights must be randomly initialized for accurate performance
    measurements. Additionally, the model weights should not cause NaNs in the
    forward pass. We empirically found that initializing the weights with
    values between -1e-3 and 1e-3 works well for most models.

    Every floating-point tensor in ``model.state_dict()`` (parameters and
    float buffers) is filled in place with values drawn uniformly from
    ``[low, high)``. Non-floating-point entries, such as integer counters
    like BatchNorm's ``num_batches_tracked``, are left untouched because
    ``uniform_`` is not implemented for integer dtypes.
    """
    for tensor in model.state_dict().values():
        if torch.is_floating_point(tensor):
            tensor.data.uniform_(low, high)