Accelerate LLaMA model loading (#234)

JFDuan 2023-08-30 16:00:13 +08:00 committed by GitHub
parent becd7a56f1
commit 0d93f15694
8 changed files with 190 additions and 112 deletions
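In short, loading is accelerated two ways: checkpoints can now be read from *.safetensors files (with lazy PySafeSlice views, so each tensor-parallel rank reads only its own shard from disk), and the per-model vocab-padding boilerplate is replaced by a shared load_padded_tensor_parallel_vocab helper. A minimal sketch of the extended iterator, with an illustrative model path:

from vllm.model_executor.weight_utils import hf_model_weights_iterator

for name, loaded_weight in hf_model_weights_iterator(
        "huggyllama/llama-7b",        # illustrative HF repo or local dir
        cache_dir=None,
        use_np_cache=False,
        use_safetensor=True):         # flag added by this commit
    # loaded_weight is a torch.Tensor for *.bin checkpoints, or a lazy
    # PySafeSlice when safetensors files are used.
    pass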

vllm/model_executor/models/aquila.py

@@ -34,8 +34,9 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -280,8 +281,7 @@ class AquilaForCausalLM(nn.Module):
return next_tokens
_column_parallel_weights = [
"embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
"gate_proj.weight", "up_proj.weight"
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
@@ -309,16 +309,6 @@ class AquilaForCausalLM(nn.Module):
if "rotary_emb.inv_freq" in name:
continue
if "embed_tokens" in name or "lm_head" in name:
param = state_dict[name]
# Consider padding in the vocab size.
padded_vocab_size = (param.shape[0] * tp_size)
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
is_attention_weight = False
for weight_name, shard_size, offset in attention_weight_specs:
if weight_name not in name:
@@ -356,6 +346,11 @@ class AquilaForCausalLM(nn.Module):
continue
param = state_dict[name]
if "embed_tokens" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
continue
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
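For reference, the inline logic deleted above padded the full-vocab weight with empty rows before the generic sharding path sliced it; load_padded_tensor_parallel_vocab (defined at the end of this diff) slices first and lets trailing rows stay as padding. A worked numeric sketch of that arithmetic, with illustrative GPT-2-style sizes:

import torch

# Illustrative sizes: vocab 50257 padded to 50304, tp_size 2, so each
# rank owns shard_size = 25152 embedding rows.
vocab_size, shard_size, rank, hidden = 50257, 25152, 1, 768
full = torch.randn(vocab_size, hidden)           # checkpoint embedding
param = torch.zeros(shard_size, hidden)          # this rank's padded shard
shard = full[rank * shard_size:(rank + 1) * shard_size]  # only 25105 rows
param[:shard.shape[0]].copy_(shard)              # last 47 rows stay padding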

vllm/model_executor/models/baichuan.py

@@ -32,10 +32,12 @@ from vllm.sequence import SequenceOutputs
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, PagedAttentionWithALiBi
from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
PagedAttentionWithALiBi)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -295,10 +297,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
input_metadata)
return next_tokens
_column_parallel_weights = [
"embed_tokens.weight",
"lm_head.weight",
]
_column_parallel_weights = []
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self,
@@ -314,16 +313,6 @@ class BaiChuanBaseForCausalLM(nn.Module):
if "rotary_emb.inv_freq" in name:
continue
if "embed_tokens" in name or "lm_head" in name:
# Consider padding in the vocab size.
param = state_dict[name]
padded_vocab_size = param.shape[0] * tp_world_size
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
if "W_pack" in name:
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
@@ -355,6 +344,12 @@ class BaiChuanBaseForCausalLM(nn.Module):
continue
param = state_dict[name]
if "embed_tokens" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight,
tp_rank)
continue
load_tensor_parallel_weights(
param,
loaded_weight,
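The W_pack branch above (its body elided in this hunk) shards BaiChuan's fused QKV projection per rank, each of the q, k and v blocks contributing one slice. A rough sketch only, with illustrative shapes, not the commit's exact code:

import torch

hidden, tp_size, rank = 4096, 2, 0
w_pack = torch.randn(3 * hidden, hidden)      # fused [q; k; v] rows
shard = hidden // tp_size
parts = [
    w_pack[i * hidden + rank * shard:i * hidden + (rank + 1) * shard]
    for i in range(3)                         # q, k, v blocks in order
]
local_qkv = torch.cat(parts, dim=0)           # [3 * hidden // tp_size, hidden]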

vllm/model_executor/models/gpt2.py

@@ -31,8 +31,9 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -224,7 +225,7 @@ class GPT2LMHeadModel(nn.Module):
input_metadata)
return next_tokens
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
_column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"]
def load_weights(self,
@@ -261,14 +262,9 @@ class GPT2LMHeadModel(nn.Module):
param = state_dict[name]
if name == "transformer.wte.weight":
# Consider padding in the vocab size.
padded_vocab_size = (param.shape[0] *
tensor_model_parallel_world_size)
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
continue
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:

vllm/model_executor/models/gpt_bigcode.py

@@ -32,8 +32,9 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -252,7 +253,7 @@ class GPTBigCodeForCausalLM(nn.Module):
input_metadata)
return next_tokens
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
_column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"]
def load_weights(self,
@@ -328,14 +329,9 @@ class GPTBigCodeForCausalLM(nn.Module):
param = state_dict[name]
if name == "transformer.wte.weight":
# Consider padding in the vocab size.
padded_vocab_size = param.shape[
0] * tensor_model_parallel_world_size
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
continue
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,

vllm/model_executor/models/internlm.py

@@ -14,8 +14,9 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights)
from vllm.sequence import SequenceOutputs
KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -225,8 +226,7 @@ class InternLMForCausalLM(nn.Module):
return next_tokens
_column_parallel_weights = [
"embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
"gate_proj.weight", "up_proj.weight"
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
@@ -234,8 +234,6 @@ class InternLMForCausalLM(nn.Module):
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
@@ -246,14 +244,9 @@ class InternLMForCausalLM(nn.Module):
if "embed_tokens" in name or "lm_head" in name:
param = state_dict[name]
# Consider padding in the vocab size.
padded_vocab_size = (param.shape[0] *
tensor_model_parallel_world_size)
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
continue
is_attention_weight = False
for stride_id, att_weight_name in enumerate(

vllm/model_executor/models/llama.py

@@ -36,8 +36,9 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.weight_utils import (
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
hf_model_weights_iterator)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -263,15 +264,15 @@ class LlamaForCausalLM(nn.Module):
return next_tokens
_column_parallel_weights = [
"embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
"gate_proj.weight", "up_proj.weight"
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
use_np_cache: bool = False,
use_safetensor: bool = True):
tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -288,20 +289,10 @@ class LlamaForCausalLM(nn.Module):
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, use_np_cache, use_safetensor):
if "rotary_emb.inv_freq" in name:
continue
if "embed_tokens" in name or "lm_head" in name:
param = state_dict[name]
# Consider padding in the vocab size.
padded_vocab_size = (param.shape[0] * tp_size)
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
is_attention_weight = False
for weight_name, shard_size, offset in attention_weight_specs:
if weight_name not in name:
@@ -339,6 +330,12 @@ class LlamaForCausalLM(nn.Module):
continue
param = state_dict[name]
if "embed_tokens" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
continue
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
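Note that use_safetensor defaults to True only here in llama.py; the weight_utils fallback (below) silently reverts to *.bin when no safetensors files exist. A usage sketch, assuming model is an already-constructed LlamaForCausalLM and an illustrative path:

model.load_weights("huggyllama/llama-7b",
                   cache_dir=None,
                   use_np_cache=False,
                   use_safetensor=True)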

vllm/model_executor/models/qwen.py

@@ -20,6 +20,7 @@ from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator,
load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights,
)
from vllm.model_executor.parallel_utils.parallel_state import (
@@ -241,7 +242,7 @@ class QWenLMHeadModel(nn.Module):
input_metadata)
return next_tokens
_column_parallel_weights = ["wte.weight", "lm_head.weight"]
_column_parallel_weights = []
_row_parallel_weights = ["c_proj.weight"]
def load_weights(
@@ -259,16 +260,6 @@ class QWenLMHeadModel(nn.Module):
if "rotary_emb.inv_freq" in name:
continue
if "wte" in name or "lm_head" in name:
# Consider padding in the vocab size.
param = state_dict[name]
padded_vocab_size = param.shape[0] * tp_world_size
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
if "c_attn" in name:
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
@@ -306,6 +297,12 @@ class QWenLMHeadModel(nn.Module):
continue
param = state_dict[name]
if "wte" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight,
tp_rank)
continue
load_tensor_parallel_weights(
param,
loaded_weight,

vllm/model_executor/weight_utils.py

@@ -3,13 +3,19 @@ import filelock
import glob
import json
import os
from typing import Iterator, List, Optional, Tuple
from collections import defaultdict
from typing import Iterator, List, Optional, Tuple, Any
from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file, safe_open
import numpy as np
import torch
from tqdm.auto import tqdm
from vllm.logger import init_logger
logger = init_logger(__name__)
class Disabledtqdm(tqdm):
@@ -17,43 +23,118 @@ class Disabledtqdm(tqdm):
super().__init__(*args, **kwargs, disable=True)
def hf_model_weights_iterator(
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False,
) -> Iterator[Tuple[str, torch.Tensor]]:
# Prepare file lock directory to prevent multiple processes from
# downloading the same model weights at the same time.
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
lock_dir = cache_dir if cache_dir is not None else "/tmp"
lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
return lock
def _shared_pointers(tensors):
ptrs = defaultdict(list)
for k, v in tensors.items():
ptrs[v.data_ptr()].append(k)
failing = []
for _, names in ptrs.items():
if len(names) > 1:
failing.append(names)
return failing
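_shared_pointers groups parameter names that alias the same storage (e.g. tied input/output embeddings) so the converter below can drop duplicates before saving; a tiny illustration:

import torch

w = torch.randn(8, 8)
tensors = {"lm_head.weight": w,              # tied: same storage as below
           "embed_tokens.weight": w,
           "norm.weight": torch.ones(8)}
# Tied names share a data_ptr(), so they come back grouped:
# [['lm_head.weight', 'embed_tokens.weight']]
print(_shared_pointers(tensors))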
def convert_bin_to_safetensor_file(
pt_filename: str,
sf_filename: str,
):
loaded = torch.load(pt_filename, map_location="cpu")
if "state_dict" in loaded:
loaded = loaded["state_dict"]
shared = _shared_pointers(loaded)
for shared_weights in shared:
for name in shared_weights[1:]:
loaded.pop(name)
# For tensors to be contiguous
loaded = {k: v.contiguous() for k, v in loaded.items()}
dirname = os.path.dirname(sf_filename)
os.makedirs(dirname, exist_ok=True)
save_file(loaded, sf_filename, metadata={"format": "pt"})
# check file size
sf_size = os.stat(sf_filename).st_size
pt_size = os.stat(pt_filename).st_size
if (sf_size - pt_size) / pt_size > 0.01:
raise RuntimeError(f"""The file size different is more than 1%:
- {sf_filename}: {sf_size}
- {pt_filename}: {pt_size}
""")
# check if the tensors are the same
reloaded = load_file(sf_filename)
for k in loaded:
pt_tensor = loaded[k]
sf_tensor = reloaded[k]
if not torch.equal(pt_tensor, sf_tensor):
raise RuntimeError(f"The output tensors do not match for key {k}")
def prepare_hf_model_weights(
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_safetensor: bool = False,
):
# Download model weights from huggingface.
is_local = os.path.isdir(model_name_or_path)
allow_patterns = "*.safetensors" if use_safetensor else "*.bin"
if not is_local:
with lock:
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
hf_folder = snapshot_download(model_name_or_path,
allow_patterns="*.bin",
allow_patterns=allow_patterns,
cache_dir=cache_dir,
tqdm_class=Disabledtqdm)
else:
hf_folder = model_name_or_path
hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
if not use_safetensor:
hf_weights_files = [
x for x in hf_weights_files if not x.endswith("training_args.bin")
]
hf_bin_files = [
x for x in glob.glob(os.path.join(hf_folder, "*.bin"))
if not x.endswith("training_args.bin")
]
if len(hf_weights_files) == 0 and use_safetensor:
logger.warning("No *.safetensors files found, "
"fall back to *.bin files")
return prepare_hf_model_weights(model_name_or_path,
cache_dir=cache_dir,
use_safetensor=False)
return hf_folder, hf_weights_files, use_safetensor
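prepare_hf_model_weights resolves the snapshot directory and weight-file list once, downloading under a file lock when the path is not a local directory; a sketch with an illustrative repo name:

hf_folder, weight_files, use_safetensor = prepare_hf_model_weights(
    "huggyllama/llama-7b", cache_dir=None, use_safetensor=True)
# use_safetensor comes back False if the repo ships only *.bin files.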
def hf_model_weights_iterator(
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False,
use_safetensor: bool = False,
) -> Iterator[Tuple[str, torch.Tensor]]:
hf_folder, hf_weights_files, use_safetensor = prepare_hf_model_weights(
model_name_or_path, cache_dir=cache_dir, use_safetensor=use_safetensor)
if use_np_cache:
# Currently np_cache only supports *.bin checkpoints
assert use_safetensor is False
# Convert the model weights from torch tensors to numpy arrays for
# faster loading.
np_folder = os.path.join(hf_folder, "np")
os.makedirs(np_folder, exist_ok=True)
weight_names_file = os.path.join(np_folder, "weight_names.json")
with lock:
# Use file lock to prevent multiple processes from
# dumping the same model weights to numpy at the same time.
with get_lock(model_name_or_path, cache_dir):
if not os.path.exists(weight_names_file):
weight_names = []
for bin_file in hf_bin_files:
for bin_file in hf_weights_files:
state = torch.load(bin_file, map_location="cpu")
for name, param in state.items():
param_path = os.path.join(np_folder, name)
@@ -71,8 +152,14 @@ def hf_model_weights_iterator
with open(param_path, "rb") as f:
param = np.load(f)
yield name, torch.from_numpy(param)
elif use_safetensor:
for st_file in hf_weights_files:
with safe_open(st_file, framework="pt") as f:
for name in f.keys():
param = f.get_slice(name)
yield name, param
else:
for bin_file in hf_bin_files:
for bin_file in hf_weights_files:
state = torch.load(bin_file, map_location="cpu")
for name, param in state.items():
yield name, param
@@ -80,9 +167,26 @@ def hf_model_weights_iterator
torch.cuda.empty_cache()
def load_padded_tensor_parallel_vocab(
param: torch.Tensor,
loaded_weight: Any, # `torch.Tensor` or `PySafeSlice`
tensor_model_parallel_rank: int,
) -> None:
shard_size = param.shape[0]
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
loaded_weight = loaded_weight[start_idx:end_idx]
# convert PySafeSlice object to torch.Tensor
if not isinstance(loaded_weight, torch.Tensor):
loaded_weight = loaded_weight[:]
param[:loaded_weight.shape[0]].copy_(loaded_weight)
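Accepting a PySafeSlice here is the point of the acceleration: the [start_idx:end_idx] slice executes against the safetensors file, so only this rank's rows are ever read from disk. A sketch with an illustrative file and tensor name:

from safetensors.torch import safe_open

with safe_open("model.safetensors", framework="pt") as f:
    lazy = f.get_slice("model.embed_tokens.weight")  # no data read yet
    shard = lazy[0:16000]    # materializes just these rows as a torch.Tensor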
def load_tensor_parallel_weights(
param: torch.Tensor,
loaded_weight: torch.Tensor,
loaded_weight: Any, # `torch.Tensor` or `PySafeSlice`
param_name: str,
column_parallel_weight_names: List[str],
row_parallel_weight_names: List[str],
@@ -102,6 +206,11 @@ def load_tensor_parallel_weights(
end_idx = (tensor_model_parallel_rank + 1) * shard_size
loaded_weight = loaded_weight[:, start_idx:end_idx]
break
# convert PySafeSlice object to torch.Tensor
if not isinstance(loaded_weight, torch.Tensor):
loaded_weight = loaded_weight[:]
assert param.shape == loaded_weight.shape, (
f"{param_name} shape mismatch between model and checkpoint: "
f"{param.shape} != {loaded_weight.shape}")