"""Utilities for downloading and initializing model weights."""
import glob
import json
import os
from collections import defaultdict
from typing import Any, Iterator, List, Optional, Tuple

import filelock
import numpy as np
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm

from vllm.logger import init_logger
from vllm.model_executor.quantization_utils import get_quant_class
from vllm.model_executor.quantization_utils.base import QuantizationConfig

logger = init_logger(__name__)


# A tqdm subclass with the progress bar permanently disabled, used to keep
# snapshot_download quiet.
class Disabledtqdm(tqdm):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, disable=True)


def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    lock_dir = cache_dir if cache_dir is not None else "/tmp"
    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
    return lock
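
# Illustrative usage (a minimal sketch; the model name is only an example):
#
#     with get_lock("facebook/opt-125m", cache_dir=None):
#         ...  # only one local process downloads these weights at a time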


def _shared_pointers(tensors):
    """Return groups of tensor names that share the same underlying storage.

    Safetensors cannot serialize aliased tensors (e.g. tied embeddings), so
    the caller keeps only one name from each group.
    """
    ptrs = defaultdict(list)
    for k, v in tensors.items():
        ptrs[v.data_ptr()].append(k)
    failing = []
    for _, names in ptrs.items():
        if len(names) > 1:
            failing.append(names)
    return failing
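
# Illustrative sketch: aliased tensors (e.g. tied embeddings) share one
# storage pointer, so both names land in the same group.
#
#     w = torch.zeros(4, 4)
#     _shared_pointers({"a.weight": w, "b.weight": w})
#     # -> [["a.weight", "b.weight"]]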


def convert_bin_to_safetensor_file(
    pt_filename: str,
    sf_filename: str,
) -> None:
    loaded = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    # Drop all but one name from each group of aliased tensors, since
    # safetensors refuses to serialize tensors that share storage.
    shared = _shared_pointers(loaded)
    for shared_weights in shared:
        for name in shared_weights[1:]:
            loaded.pop(name)

    # Force the tensors to be contiguous, as required by safetensors.
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    dirname = os.path.dirname(sf_filename)
    os.makedirs(dirname, exist_ok=True)
    save_file(loaded, sf_filename, metadata={"format": "pt"})

    # Check the file size.
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if (sf_size - pt_size) / pt_size > 0.01:
        raise RuntimeError(f"""The file size difference is more than 1%:
         - {sf_filename}: {sf_size}
         - {pt_filename}: {pt_size}
         """)

    # Check that the tensors are identical after the round trip.
    reloaded = load_file(sf_filename)
    for k in loaded:
        pt_tensor = loaded[k]
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")


# TODO(woosuk): Move this to another place.
def get_quant_config(
    quantization: str,
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
) -> QuantizationConfig:
    is_local = os.path.isdir(model_name_or_path)
    if not is_local:
        # Download the config files.
        with get_lock(model_name_or_path, cache_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns="*.json",
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm)
    else:
        hf_folder = model_name_or_path
    config_files = glob.glob(os.path.join(hf_folder, "*.json"))

    quant_cls = get_quant_class(quantization)
    quant_config_files = [
        f for f in config_files if any(
            f.endswith(x) for x in quant_cls.get_config_filenames())
    ]
    if len(quant_config_files) == 0:
        raise ValueError(f"Cannot find the config file for {quantization}")
    if len(quant_config_files) > 1:
        raise ValueError(f"Found multiple config files for {quantization}: "
                         f"{quant_config_files}")

    quant_config_file = quant_config_files[0]
    with open(quant_config_file, "r") as f:
        config = json.load(f)
    return quant_cls.from_config(config)
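
# Illustrative usage (a sketch; the model name is hypothetical, and the
# matched config file depends on quant_cls.get_config_filenames()):
#
#     quant_config = get_quant_config("awq", "some-org/some-awq-model")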


def prepare_hf_model_weights(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_safetensors: bool = False,
    fall_back_to_pt: bool = True,
    revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
    # Download model weights from the Hugging Face Hub.
    is_local = os.path.isdir(model_name_or_path)
    if use_safetensors:
        allow_patterns = ["*.safetensors"]
    else:
        # Some quantized models use .pt files for storing the weights.
        allow_patterns = ["*.bin", "*.pt"]
    if not is_local:
        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(model_name_or_path, cache_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns=allow_patterns,
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm,
                                          revision=revision)
    else:
        hf_folder = model_name_or_path
    hf_weights_files: List[str] = []
    for pattern in allow_patterns:
        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
    if not use_safetensors:
        # Exclude files that are not needed for inference.
        # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
        blacklist = [
            "training_args.bin",
            "optimizer.bin",
            "optimizer.pt",
            "scheduler.pt",
            "scaler.pt",
        ]
        hf_weights_files = [
            f for f in hf_weights_files
            if not any(f.endswith(x) for x in blacklist)
        ]

    # If no safetensors files were found, optionally retry once with the
    # .bin/.pt patterns.
    if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
        return prepare_hf_model_weights(model_name_or_path,
                                        cache_dir=cache_dir,
                                        use_safetensors=False,
                                        fall_back_to_pt=False,
                                        revision=revision)

    if len(hf_weights_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`")

    return hf_folder, hf_weights_files, use_safetensors
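
# Illustrative usage (a sketch; the model name is only an example). The
# returned flag reports which format was actually found after any fallback:
#
#     folder, files, is_safetensors = prepare_hf_model_weights(
#         "facebook/opt-125m", use_safetensors=True, fall_back_to_pt=True)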


def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    revision: Optional[str] = None,
) -> Iterator[Tuple[str, torch.Tensor]]:
    use_safetensors = False
    use_np_cache = False
    fall_back_to_pt = False
    if load_format == "auto":
        use_safetensors = True
        fall_back_to_pt = True
    elif load_format == "safetensors":
        use_safetensors = True
    elif load_format == "pt":
        pass
    elif load_format == "npcache":
        use_np_cache = True
    else:
        raise ValueError(f"Unknown load_format: {load_format}")

    hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
        model_name_or_path,
        cache_dir=cache_dir,
        use_safetensors=use_safetensors,
        fall_back_to_pt=fall_back_to_pt,
        revision=revision)

    if use_np_cache:
        # Currently np_cache only supports *.bin checkpoints.
        assert use_safetensors is False

        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")
        # Use file lock to prevent multiple processes from
        # dumping the same model weights to numpy at the same time.
        with get_lock(model_name_or_path, cache_dir):
            if not os.path.exists(weight_names_file):
                weight_names = []
                for bin_file in hf_weights_files:
                    state = torch.load(bin_file, map_location="cpu")
                    for name, param in state.items():
                        param_path = os.path.join(np_folder, name)
                        with open(param_path, "wb") as f:
                            np.save(f, param.cpu().detach().numpy())
                        weight_names.append(name)
                with open(weight_names_file, "w") as f:
                    json.dump(weight_names, f)

        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)

        for name in weight_names:
            param_path = os.path.join(np_folder, name)
            with open(param_path, "rb") as f:
                param = np.load(f)
            yield name, torch.from_numpy(param)
    elif use_safetensors:
        for st_file in hf_weights_files:
            with safe_open(st_file, framework="pt") as f:
                for name in f.keys():
                    # Note: get_slice returns a lazy PySafeSlice; callers
                    # materialize it with convert_pyslice_to_tensor().
                    param = f.get_slice(name)
                    yield name, param
    else:
        for bin_file in hf_weights_files:
            state = torch.load(bin_file, map_location="cpu")
            for name, param in state.items():
                yield name, param
            del state
            torch.cuda.empty_cache()
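
# Illustrative usage (a sketch; the model name is only an example):
#
#     for name, weight in hf_model_weights_iterator("facebook/opt-125m"):
#         weight = convert_pyslice_to_tensor(weight)
#         print(name, tuple(weight.shape))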


def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    """Convert a PySafeSlice object from safetensors to a torch.Tensor.

    A PySafeSlice supports indexing, which is applied before the actual
    tensor is loaded and thus reduces the amount of data read from disk.
    However, it does not support more advanced functionality such as
    `.view()` or `.t()`, so the slice must first be materialized as a
    tensor before applying such operators.
    """
    if not isinstance(x, torch.Tensor):
        x = x[:]
    return x
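
# Illustrative sketch (the file and tensor names are hypothetical): the
# helper accepts either a raw PySafeSlice or an already-materialized tensor.
#
#     with safe_open("model.safetensors", framework="pt") as f:
#         sl = f.get_slice("embed_tokens.weight")
#         full = convert_pyslice_to_tensor(sl)  # loads the whole tensor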


def load_padded_tensor_parallel_vocab(
    param: torch.Tensor,
    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
    tensor_model_parallel_rank: int,
) -> None:
    shard_size = param.shape[0]
    start_idx = tensor_model_parallel_rank * shard_size
    end_idx = (tensor_model_parallel_rank + 1) * shard_size
    loaded_weight = loaded_weight[start_idx:end_idx]
    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
    # The checkpoint shard may be smaller than the (padded) parameter,
    # so copy only the rows that actually exist in the checkpoint.
    param[:loaded_weight.shape[0]].copy_(loaded_weight)
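
# Worked example (a sketch with hypothetical sizes): with a GPT-2 style
# vocabulary of 50257 padded to 50304 and 2-way tensor parallelism, each
# shard holds 50304 // 2 = 25152 rows. Rank 1 slices rows [25152:50304],
# but the checkpoint only has 50257 rows, so 25105 rows are copied and the
# remaining 47 padded rows keep their initialized values.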


def load_tensor_parallel_weights(
    param: torch.Tensor,
    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
    param_name: str,
    column_parallel_weight_names: List[str],
    row_parallel_weight_names: List[str],
    tensor_model_parallel_rank: int,
) -> None:
    # Column-parallel weights are sharded along the output dimension (dim 0).
    for p in column_parallel_weight_names:
        if p in param_name:
            shard_size = param.shape[0]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            loaded_weight = loaded_weight[start_idx:end_idx]
            break
    # Row-parallel weights are sharded along the input dimension (dim 1).
    for p in row_parallel_weight_names:
        if p in param_name:
            shard_size = param.shape[1]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            loaded_weight = loaded_weight[:, start_idx:end_idx]
            break

    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
    assert param.shape == loaded_weight.shape, (
        f"{param_name} shape mismatch between model and checkpoint: "
        f"{param.shape} != {loaded_weight.shape}")
    param.data.copy_(loaded_weight)
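
# Illustrative sketch (names and sizes are hypothetical): a column-parallel
# weight of full shape (1024, 512) under 2-way tensor parallelism gives each
# rank a (512, 512) param; rank 0 copies rows [0:512], rank 1 rows [512:1024].
#
#     load_tensor_parallel_weights(param, loaded_weight, "qkv_proj.weight",
#                                  ["qkv_proj.weight"], [], rank)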


def initialize_dummy_weights(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
) -> None:
    """Initialize model weights with random values.

    The model weights must be randomly initialized for accurate performance
    measurements. Additionally, the model weights should not cause NaNs in the
    forward pass. We empirically found that initializing the weights with
    values between -1e-3 and 1e-3 works well for most models.
    """
    for param in model.state_dict().values():
        param.data.uniform_(low, high)
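
# Illustrative usage (a sketch with a toy model):
#
#     model = torch.nn.Linear(16, 16)
#     initialize_dummy_weights(model)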
|