# vllm/vllm/model_executor/weight_utils.py

"""Utilities for downloading and initializing model weights."""
2023-05-09 15:30:12 -07:00
import filelock
import glob
import json
2023-05-09 15:30:12 -07:00
import os
2023-08-30 16:00:13 +08:00
from collections import defaultdict
from typing import Any, Iterator, List, Optional, Tuple
2023-05-09 15:30:12 -07:00
from huggingface_hub import snapshot_download
2023-08-30 16:00:13 +08:00
from safetensors.torch import load_file, save_file, safe_open
import numpy as np
import torch
from tqdm.auto import tqdm
2023-08-30 16:00:13 +08:00
from vllm.logger import init_logger
from vllm.model_executor.quantization_utils import get_quant_class
from vllm.model_executor.quantization_utils.base import QuantizationConfig
2023-08-30 16:00:13 +08:00
logger = init_logger(__name__)


class Disabledtqdm(tqdm):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, disable=True)


def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    lock_dir = cache_dir if cache_dir is not None else "/tmp"
    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
    return lock
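
# Usage sketch (illustrative only; "facebook/opt-125m" is an arbitrary model
# id). The file lock serializes concurrent downloads of the same model:
#
#     with get_lock("facebook/opt-125m", cache_dir=None):
#         ...  # only one process per model enters this block at a time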


def _shared_pointers(tensors):
    ptrs = defaultdict(list)
    for k, v in tensors.items():
        ptrs[v.data_ptr()].append(k)
    failing = []
    for _, names in ptrs.items():
        if len(names) > 1:
            failing.append(names)
    return failing
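
# Illustrative sketch of what _shared_pointers reports: state-dict entries
# backed by the same storage (e.g. tied embeddings) are grouped together so
# callers can drop the duplicates before serializing to safetensors:
#
#     w = torch.zeros(4)
#     assert _shared_pointers({"a": w, "b": w, "c": torch.ones(4)}) == [["a", "b"]]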


def convert_bin_to_safetensor_file(
    pt_filename: str,
    sf_filename: str,
) -> None:
    loaded = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    shared = _shared_pointers(loaded)
    for shared_weights in shared:
        for name in shared_weights[1:]:
            loaded.pop(name)

    # Force the tensors to be contiguous before saving.
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    dirname = os.path.dirname(sf_filename)
    os.makedirs(dirname, exist_ok=True)
    save_file(loaded, sf_filename, metadata={"format": "pt"})

    # Check that the file sizes roughly match.
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if (sf_size - pt_size) / pt_size > 0.01:
        raise RuntimeError(f"""The file size difference is more than 1%:
         - {sf_filename}: {sf_size}
         - {pt_filename}: {pt_size}
         """)

    # Check that the reloaded tensors match the originals.
    reloaded = load_file(sf_filename)
    for k in loaded:
        pt_tensor = loaded[k]
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")
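
# Usage sketch (the paths below are hypothetical; any checkpoint saved with
# `torch.save` should work):
#
#     convert_bin_to_safetensor_file("ckpt/pytorch_model.bin",
#                                    "ckpt/model.safetensors")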


# TODO(woosuk): Move this to another place.
def get_quant_config(
    quantization: str,
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
) -> QuantizationConfig:
    is_local = os.path.isdir(model_name_or_path)
    if not is_local:
        # Download the config files.
        with get_lock(model_name_or_path, cache_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns="*.json",
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm)
    else:
        hf_folder = model_name_or_path
    config_files = glob.glob(os.path.join(hf_folder, "*.json"))

    quant_cls = get_quant_class(quantization)
    quant_config_files = [
        f for f in config_files if any(
            f.endswith(x) for x in quant_cls.get_config_filenames())
    ]
    if len(quant_config_files) == 0:
        raise ValueError(f"Cannot find the config file for {quantization}")
    if len(quant_config_files) > 1:
        raise ValueError(f"Found multiple config files for {quantization}: "
                         f"{quant_config_files}")

    quant_config_file = quant_config_files[0]
    with open(quant_config_file, "r") as f:
        config = json.load(f)
    return quant_cls.from_config(config)
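
# Usage sketch (assumes "awq" is a registered quantization method at this
# revision and that the repo, whose id below is hypothetical, ships a
# quantization config JSON):
#
#     quant_config = get_quant_config("awq", "some-org/some-awq-model")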


def prepare_hf_model_weights(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_safetensors: bool = False,
    fall_back_to_pt: bool = True,
    revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
    # Download model weights from huggingface.
    is_local = os.path.isdir(model_name_or_path)
    if use_safetensors:
        allow_patterns = ["*.safetensors"]
    else:
        # Some quantized models use .pt files for storing the weights.
        allow_patterns = ["*.bin", "*.pt"]
    if not is_local:
        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(model_name_or_path, cache_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns=allow_patterns,
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm,
                                          revision=revision)
    else:
        hf_folder = model_name_or_path
    hf_weights_files: List[str] = []
    for pattern in allow_patterns:
        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
    if not use_safetensors:
        # Exclude files that are not needed for inference.
        # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
        blacklist = [
            "training_args.bin",
            "optimizer.bin",
            "optimizer.pt",
            "scheduler.pt",
            "scaler.pt",
        ]
        hf_weights_files = [
            f for f in hf_weights_files
            if not any(f.endswith(x) for x in blacklist)
        ]

    if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
        return prepare_hf_model_weights(model_name_or_path,
                                        cache_dir=cache_dir,
                                        use_safetensors=False,
                                        fall_back_to_pt=False,
                                        revision=revision)

    if len(hf_weights_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`")

    return hf_folder, hf_weights_files, use_safetensors
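
# Usage sketch (illustrative; the model id is arbitrary). With the default
# arguments only *.bin / *.pt shards are considered:
#
#     hf_folder, weight_files, is_safetensors = prepare_hf_model_weights(
#         "facebook/opt-125m")
#     # `hf_folder` is the local snapshot directory and `weight_files` the
#     # weight shards that were found under it.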


def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    revision: Optional[str] = None,
) -> Iterator[Tuple[str, torch.Tensor]]:
    use_safetensors = False
    use_np_cache = False
    fall_back_to_pt = False
    if load_format == "auto":
        use_safetensors = True
        fall_back_to_pt = True
    elif load_format == "safetensors":
        use_safetensors = True
    elif load_format == "pt":
        pass
    elif load_format == "npcache":
        use_np_cache = True
    else:
        raise ValueError(f"Unknown load_format: {load_format}")

    hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
        model_name_or_path,
        cache_dir=cache_dir,
        use_safetensors=use_safetensors,
        fall_back_to_pt=fall_back_to_pt,
        revision=revision)

    if use_np_cache:
        # Currently np_cache only supports *.bin checkpoints.
        assert use_safetensors is False

        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")
        # Use file lock to prevent multiple processes from
        # dumping the same model weights to numpy at the same time.
        with get_lock(model_name_or_path, cache_dir):
            if not os.path.exists(weight_names_file):
                weight_names = []
                for bin_file in hf_weights_files:
                    state = torch.load(bin_file, map_location="cpu")
                    for name, param in state.items():
                        param_path = os.path.join(np_folder, name)
                        with open(param_path, "wb") as f:
                            np.save(f, param.cpu().detach().numpy())
                        weight_names.append(name)
                with open(weight_names_file, "w") as f:
                    json.dump(weight_names, f)

        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)

        for name in weight_names:
            param_path = os.path.join(np_folder, name)
            with open(param_path, "rb") as f:
                param = np.load(f)
            yield name, torch.from_numpy(param)
    elif use_safetensors:
        for st_file in hf_weights_files:
            with safe_open(st_file, framework="pt") as f:
                for name in f.keys():
                    param = f.get_slice(name)
                    yield name, param
    else:
        for bin_file in hf_weights_files:
            state = torch.load(bin_file, map_location="cpu")
            for name, param in state.items():
                yield name, param
            del state
            torch.cuda.empty_cache()
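
# Usage sketch (illustrative): stream the weights of a checkpoint one tensor
# at a time instead of materializing the whole state dict:
#
#     for name, weight in hf_model_weights_iterator("facebook/opt-125m",
#                                                   load_format="auto"):
#         weight = convert_pyslice_to_tensor(weight)
#         print(name, tuple(weight.shape))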


def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    """Convert a PySafeSlice object from safetensors to a torch.Tensor.

    A PySafeSlice supports indexing, which is applied before the actual
    tensor is loaded and can therefore reduce the amount of memory read
    from disk. However, it does not support more advanced operations such
    as `.view()` or `.t()`, so a slice must be converted to a full tensor
    before applying those operators.
    """
    if not isinstance(x, torch.Tensor):
        x = x[:]
    return x
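
# Illustrative sketch: a plain tensor passes through unchanged, while a
# PySafeSlice (as yielded by the safetensors branch above) is materialized
# by the `x[:]` indexing:
#
#     t = torch.arange(6)
#     assert convert_pyslice_to_tensor(t) is t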


def load_padded_tensor_parallel_vocab(
    param: torch.Tensor,
    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
    tensor_model_parallel_rank: int,
) -> None:
    shard_size = param.shape[0]
    start_idx = tensor_model_parallel_rank * shard_size
    end_idx = (tensor_model_parallel_rank + 1) * shard_size
    loaded_weight = loaded_weight[start_idx:end_idx]
    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
    param[:loaded_weight.shape[0]].copy_(loaded_weight)
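
# Illustrative sketch (single shard, rank 0): a vocab padded from 10 to 12
# rows receives the 10 checkpoint rows while its padding rows stay zero:
#
#     param = torch.zeros(12, 4)             # padded vocab shard
#     checkpoint_rows = torch.ones(10, 4)    # unpadded checkpoint weight
#     load_padded_tensor_parallel_vocab(param, checkpoint_rows, 0)
#     assert param[:10].eq(1).all() and param[10:].eq(0).all()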


def load_tensor_parallel_weights(
    param: torch.Tensor,
    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
    param_name: str,
    column_parallel_weight_names: List[str],
    row_parallel_weight_names: List[str],
    tensor_model_parallel_rank: int,
) -> None:
    for p in column_parallel_weight_names:
        if p in param_name:
            shard_size = param.shape[0]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            loaded_weight = loaded_weight[start_idx:end_idx]
            break
    for p in row_parallel_weight_names:
        if p in param_name:
            shard_size = param.shape[1]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            loaded_weight = loaded_weight[:, start_idx:end_idx]
            break

    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
    assert param.shape == loaded_weight.shape, (
        f"{param_name} shape mismatch between model and checkpoint: "
        f"{param.shape} != {loaded_weight.shape}")
    param.data.copy_(loaded_weight)
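
# Illustrative sketch (tensor parallel size 2, rank 1): a column-parallel
# weight is sharded along dim 0, so rank 1 copies the second half:
#
#     full = torch.arange(8.0).reshape(8, 1)
#     shard = torch.empty(4, 1)
#     load_tensor_parallel_weights(shard, full, "qkv_proj.weight",
#                                  ["qkv_proj.weight"], [], 1)
#     assert shard.flatten().tolist() == [4.0, 5.0, 6.0, 7.0]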


def initialize_dummy_weights(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
) -> None:
    """Initialize model weights with random values.

    The model weights must be randomly initialized for accurate performance
    measurements. Additionally, the model weights should not cause NaNs in
    the forward pass. We empirically found that initializing the weights
    with values between -1e-3 and 1e-3 works well for most models.
    """
    for param in model.state_dict().values():
        param.data.uniform_(low, high)
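
# Usage sketch: fill any module's parameters with small random values, e.g.
# when profiling memory or latency without downloading real weights:
#
#     model = torch.nn.Linear(16, 16)
#     initialize_dummy_weights(model)
#     assert model.weight.abs().max() <= 1e-3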