New weight loader without np copy (#52)

Zhuohan Li 2023-05-03 15:32:04 +08:00 committed by GitHub
parent 4858f3bb45
commit 27f1410d06
12 changed files with 284 additions and 352 deletions
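
In short, the commit replaces the old per-model get_weights() step, which always converted every checkpoint into per-tensor numpy files before loading, with a shared hf_model_weights_iterator() that streams tensors straight from the downloaded Hugging Face *.bin shards; the numpy copy becomes an opt-in cache behind --use-np-cache. A rough sketch of the new load path follows; load_weights_sketch is a hypothetical stand-in for the per-model load_weights methods, and facebook/opt-125m is only an example model name.

from cacheflow.models.utils import (hf_model_weights_iterator,
                                    load_tensor_parallel_weights)

def load_weights_sketch(model, model_name_or_path, cache_dir=None,
                        use_np_cache=False):
    # Stream (name, tensor) pairs directly from the HF checkpoint shards.
    state_dict = model.state_dict()
    for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, use_np_cache):
        # The real load_weights methods also skip or remap some checkpoint
        # names (e.g. rotary_emb.inv_freq) before this lookup.
        param = state_dict[name]
        # Slice column-/row-parallel weights for this TP rank, then copy.
        load_tensor_parallel_weights(param, loaded_weight, name,
                                     model._column_parallel_weights,
                                     model._row_parallel_weights)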

View File

@ -6,53 +6,15 @@ from tqdm import tqdm
import numpy as np
import torch
from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
process_server_arguments,
initialize_cluster)
from cacheflow.master.server import (
add_server_arguments, process_server_arguments,
init_local_server_and_frontend_with_arguments)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory
def main(args: argparse.Namespace):
# TODO(zhuohan): Support pipeline parallelism.
assert args.pipeline_parallel_size == 1, (
'Pipeline parallelism is not supported yet.')
server, frontend = init_local_server_and_frontend_with_arguments(args)
(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_cluster(
use_ray=args.use_ray,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))
# Create a server.
server = Server(
model=args.model,
model_path=args.model_path,
use_dummy_weights=args.use_dummy_weights,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
use_ray=args.use_ray,
)
# Create a frontend.
frontend = SimpleFrontend(
model_name=args.model,
block_size=args.block_size,
)
sampling_params_dict = {
'n': args.n,
'temperature': 0.0 if args.use_beam_search else 1.0,

View File

@ -9,57 +9,18 @@ from tqdm import tqdm
from transformers import AutoConfig
from benchmark.trace import generate_text_completion_requests
from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
process_server_arguments,
initialize_cluster)
from cacheflow.master.server import (
add_server_arguments, process_server_arguments,
init_local_server_and_frontend_with_arguments)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory
logger = logging.getLogger(__name__)
def main(args: argparse.Namespace):
assert args.pipeline_parallel_size == 1, (
'Pipeline parallelism is not supported yet.')
server, frontend = init_local_server_and_frontend_with_arguments(args)
(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_cluster(
use_ray=args.use_ray,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))
# Create a server.
server = Server(
model=args.model,
model_path=args.model_path,
use_dummy_weights=args.use_dummy_weights,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
use_ray=args.use_ray,
collect_stats=True,
do_memory_analysis=args.do_memory_analysis,
)
# Create a frontend.
frontend = SimpleFrontend(
model_name=args.model,
block_size=args.block_size,
)
# Generate requests.
requests = generate_text_completion_requests(
args.dataset,

View File

@ -1,7 +1,7 @@
import argparse
import asyncio
import time
from typing import List, Dict
from typing import List, Dict, Optional
import json
import ray
@ -22,11 +22,12 @@ TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
app = FastAPI()
class FastAPIFrontend:
class FastAPIServer:
def __init__(
self,
model: str,
model_path: str,
cache_dir: Optional[str],
use_np_cache: bool,
pipeline_parallel_size: int,
tensor_parallel_size: int,
block_size: int,
@ -52,8 +53,9 @@ class FastAPIFrontend:
remote_server_class = ray.remote(num_gpus=1)(Server)
self.server = remote_server_class.remote(
model=model,
model_path=model_path,
cache_dir=cache_dir,
use_dummy_weights=False,
use_np_cache=use_np_cache,
pipeline_parallel_size=pipeline_parallel_size,
tensor_parallel_size=tensor_parallel_size,
block_size=block_size,
@ -148,7 +150,7 @@ class FastAPIFrontend:
@app.post("/generate")
async def generate_stream(request: Request):
request_dict = await request.json()
return StreamingResponse(frontend.generate(request_dict))
return StreamingResponse(server.generate(request_dict))
if __name__ == "__main__":
@ -170,9 +172,10 @@ if __name__ == "__main__":
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))
frontend = FastAPIFrontend(
server = FastAPIServer(
model=args.model,
model_path=args.model_path,
cache_dir=args.cache_dir,
use_np_cache=args.use_np_cache,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,

View File

@ -9,18 +9,21 @@ except ImportError:
ray = None
from cacheflow.master.scheduler import Scheduler
from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.models import get_memory_analyzer
from cacheflow.worker.controller import Controller, DeviceID
from cacheflow.sequence import SequenceGroup
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory
class Server:
def __init__(
self,
model: str,
model_path: str,
cache_dir: Optional[str],
use_dummy_weights: bool,
use_np_cache: bool,
pipeline_parallel_size: int,
tensor_parallel_size: int,
block_size: int,
@ -78,8 +81,9 @@ class Server:
num_cpu_blocks=self.num_cpu_blocks,
dtype=dtype,
seed=seed,
model_path=model_path,
cache_dir=cache_dir,
use_dummy_weights=use_dummy_weights,
use_np_cache=use_np_cache,
max_num_batched_tokens=max_num_batched_tokens,
use_ray=use_ray,
)
@ -203,25 +207,72 @@ def initialize_cluster(
def add_server_arguments(parser: argparse.ArgumentParser):
# Model arguments
parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
help='model path to download and load the weights')
parser.add_argument('--cache-dir', type=str, default=None,
help='cache directory to download and load the weights; '
'defaults to the default Hugging Face cache directory')
parser.add_argument('--use-np-cache', action='store_true',
help='save a numpy copy of model weights for faster loading')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
# Parallel arguments
parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
return parser
def process_server_arguments(args: argparse.Namespace):
if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
args.use_ray = True
return args
def init_local_server_and_frontend_with_arguments(args: argparse.Namespace):
# TODO(zhuohan): Support pipeline parallelism.
assert args.pipeline_parallel_size == 1, (
'Pipeline parallelism is not supported yet.')
(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_cluster(
use_ray=args.use_ray,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))
# Create a server.
server = Server(
model=args.model,
cache_dir=args.cache_dir,
use_dummy_weights=args.use_dummy_weights,
use_np_cache=args.use_np_cache,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
use_ray=args.use_ray,
)
# Create a frontend.
frontend = SimpleFrontend(
model_name=args.model,
block_size=args.block_size,
)
return server, frontend
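
With init_local_server_and_frontend_with_arguments factored out, an entry script only parses arguments and hands them over. A minimal sketch, assuming the flags added above; the model name and the inline argv list are illustrative only:

import argparse
from cacheflow.master.server import (
    add_server_arguments, process_server_arguments,
    init_local_server_and_frontend_with_arguments)

parser = argparse.ArgumentParser()
parser = add_server_arguments(parser)
args = process_server_arguments(
    parser.parse_args(['--model', 'facebook/opt-125m', '--use-np-cache']))
server, frontend = init_local_server_and_frontend_with_arguments(args)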

View File

@ -1,18 +1,14 @@
"""1D GPT-NeoX model compatible with HuggingFace weights."""
import os
import glob
import filelock
from tqdm import tqdm
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from huggingface_hub import snapshot_download
from cacheflow.models import InputMetadata
from cacheflow.models.attention import GPTNeoXCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.models.utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from cacheflow.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
@ -196,17 +192,22 @@ class GPTNeoXForCausalLM(nn.Module):
_column_parallel_weights = ["embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
def load_weights(self, weights_path: str):
def load_weights(self, model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, param in state_dict.items():
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name):
continue
param = state_dict[name]
if "query_key_value" in name:
# NOTE(woosuk): GPT-NeoX's fused QKV has the shape of
# [num_heads * 3 * head_size, num_heads * head_size], while the
# required shape is [3 * num_heads * head_size, num_heads * head_size].
# Thus, we need weight conversion.
loaded_weight = torch.from_numpy(
np.load(os.path.join(weights_path, name)))
shard_size = param.shape[0]
loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
@ -223,55 +224,10 @@ class GPTNeoXForCausalLM(nn.Module):
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1).contiguous()
else:
assert False
else:
loaded_weight = torch.from_numpy(
np.load(os.path.join(weights_path, name)))
for p in self._column_parallel_weights:
if p in name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
for p in self._row_parallel_weights:
if p in name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
assert param.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
@staticmethod
def get_weights(model_name: str, path: str):
path = os.path.join(path, f"{model_name}-np")
path = os.path.abspath(os.path.expanduser(path))
os.makedirs(path, exist_ok=True)
lock_path = os.path.join(path, "file_lock")
lock = filelock.FileLock(lock_path)
with lock:
test_weight_path = os.path.join(
path, "gpt_neox.embed_in.weight")
if os.path.exists(test_weight_path):
return path
folder = snapshot_download(model_name, allow_patterns="*.bin",
cache_dir=os.path.join(path, "cache"))
bin_files = glob.glob(os.path.join(folder, "*.bin"))
for bin_file in tqdm(bin_files, desc="Convert format"):
state = torch.load(bin_file, map_location="cpu")
for name, param in tqdm(state.items(), leave=False):
param_path = os.path.join(path, name)
with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy())
return path
raise ValueError(f"Unexpected weight name: {name}")
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights)
def initialize_dummy_weights(self) -> None:
for param in self.state_dict().values():
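
The NOTE in the hunk above is the subtle part of the GPT-NeoX path: the Hugging Face checkpoint stores the fused QKV weight head-major, [num_heads * 3 * head_size, hidden], while the parallel attention layer wants it QKV-major, [3 * num_heads * head_size, hidden]. A standalone illustration of the same view/transpose/reshape trick, with tiny made-up sizes and no tensor parallelism:

import torch

num_heads, head_size, hidden = 2, 3, 6
# Head-major rows: [h0_q, h0_k, h0_v, h1_q, h1_k, h1_v] (head_size rows each).
w = torch.arange(num_heads * 3 * head_size * hidden, dtype=torch.float32)
w = w.view(num_heads * 3 * head_size, hidden)
# Re-order to QKV-major rows: [h0_q, h1_q, h0_k, h1_k, h0_v, h1_v].
w = w.view(num_heads, 3, head_size, hidden)
w = w.transpose(0, 1)
w = w.reshape(3 * num_heads * head_size, hidden).contiguous()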

View File

@ -1,11 +1,6 @@
"""1D LLaMA model compatible with HuggingFace weights."""
import os
import glob
import filelock
from tqdm import tqdm
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from transformers import LlamaConfig
@ -15,6 +10,8 @@ from cacheflow.models.activation import SiluAndMul
from cacheflow.models.attention import LlamaCacheFlowAttention
from cacheflow.models.layernorm import RMSNorm
from cacheflow.models.sample import Sampler
from cacheflow.models.utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from cacheflow.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
@ -216,76 +213,57 @@ class LlamaForCausalLM(nn.Module):
"up_proj.weight"]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self, weights_path: str):
def load_weights(self, model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, param in state_dict.items():
if "qkv_proj" in name or "gate_up_proj" in name:
if "qkv_proj" in name:
original_name = "qkv_proj"
weight_names = ["q_proj", "k_proj", "v_proj"]
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "rotary_emb.inv_freq" in name:
continue
is_attention_weight = False
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3
else:
original_name = "gate_up_proj"
weight_names = ["gate_proj", "up_proj"]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id
:shard_size * (stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break
if is_attention_weight:
continue
is_gate_up_weight = False
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
weights_to_concat = []
for weight_name in weight_names:
weight = np.load(os.path.join(
weights_path, name.replace(original_name, weight_name)))
weights_to_concat.append(weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)])
loaded_weight = torch.from_numpy(
np.concatenate(weights_to_concat, axis=0))
else:
loaded_weight = torch.from_numpy(
np.load(os.path.join(weights_path, name)))
for p in self._column_parallel_weights:
if p in name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id
:shard_size * (stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
for p in self._row_parallel_weights:
if p in name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
if is_gate_up_weight:
continue
assert param.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
@staticmethod
def get_weights(model_name: str, path: str):
if not os.path.isfile(os.path.join(model_name, "config.json")):
raise ValueError("LLaMA model's model_name has to be a path"
"to the huggingface model's directory.")
path = os.path.join(model_name, f"np")
path = os.path.abspath(os.path.expanduser(path))
os.makedirs(path, exist_ok=True)
lock_path = os.path.join(path, "file_lock")
lock = filelock.FileLock(lock_path)
with lock:
test_weight_path = os.path.join(path, "model.embed_tokens.weight")
if os.path.exists(test_weight_path):
return path
bin_files = glob.glob(os.path.join(model_name, "*.bin"))
for bin_file in tqdm(bin_files, desc="Convert format"):
state = torch.load(bin_file, map_location="cpu")
for name, param in tqdm(state.items(), leave=False):
param_path = os.path.join(path, name)
with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy())
return path
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights)
def initialize_dummy_weights(self) -> None:
for param in self.state_dict().values():
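
For LLaMA the checkpoint keeps q_proj/k_proj/v_proj (and gate_proj/up_proj) as separate tensors, so the new loader copies each one into its stride of the fused qkv_proj / gate_up_proj parameter rather than concatenating numpy arrays up front. A minimal sketch of that slice-and-copy step, with made-up shapes and no tensor parallelism:

import torch

hidden = 8
qkv_proj = torch.empty(3 * hidden, hidden)  # fused destination parameter
shard_size = qkv_proj.shape[0] // 3
# Stand-ins for the q_proj, k_proj, v_proj tensors read from the checkpoint.
separate_weights = [torch.randn(hidden, hidden) for _ in range(3)]
for stride_id, loaded_weight in enumerate(separate_weights):
    param_slice = qkv_proj[shard_size * stride_id:shard_size * (stride_id + 1)]
    assert param_slice.shape == loaded_weight.shape
    param_slice.copy_(loaded_weight)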

View File

@ -1,4 +1,4 @@
from typing import Union
from typing import Union, Optional
import torch
import torch.nn as nn
@ -32,8 +32,9 @@ _MEMORY_ANALYZERS = {
def get_model(
model_name: str,
dtype: Union[torch.dtype, str],
path: str,
cache_dir: Optional[str],
use_dummy_weights: bool,
use_np_cache: bool,
) -> nn.Module:
torch_dtype = get_torch_dtype(dtype)
torch.set_default_dtype(torch_dtype)
@ -49,12 +50,10 @@ def get_model(
# random values to the weights.
model.initialize_dummy_weights()
else:
# Download model weights if it's not cached.
weights_dir = model_class.get_weights(model_name, path=path)
# Create a model instance.
model = model_class(config)
# Load the weights from the cached or downloaded files.
model.load_weights(weights_dir)
model.load_weights(model_name, cache_dir, use_np_cache)
model = model.cuda()
return model.eval(), torch_dtype
raise ValueError(f'Unsupported model name: {model_name}')
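
get_model now takes the Hugging Face cache directory and the np-cache flag directly instead of a pre-converted weights path. A hedged usage sketch, assuming get_model is exported from cacheflow.models as the worker's call suggests; the model name is only an example, and a CUDA device is needed because the model is moved to GPU inside get_model:

from cacheflow.models import get_model

model, torch_dtype = get_model(
    model_name='facebook/opt-125m', dtype='half', cache_dir=None,
    use_dummy_weights=False, use_np_cache=False)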

View File

@ -1,19 +1,15 @@
"""1D OPT model compatible with HuggingFace weights."""
import os
import glob
import filelock
from tqdm import tqdm
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from transformers import OPTConfig
from huggingface_hub import snapshot_download
from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.models.utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from cacheflow.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
@ -257,73 +253,42 @@ class OPTForCausalLM(nn.Module):
_column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
_row_parallel_weights = ["out_proj.weight", "fc2.weight"]
def load_weights(self, weights_path: str):
def load_weights(self, model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, param in state_dict.items():
if "lm_head_weight" in name:
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name:
continue
if "qkv_proj" in name:
shard_size = param.shape[0] // 3
weights_to_concat = []
for weight_name in ["q_proj", "k_proj", "v_proj"]:
weight = np.load(os.path.join(
weights_path, name.replace("qkv_proj", weight_name)))
weights_to_concat.append(weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)])
loaded_weight = torch.from_numpy(
np.concatenate(weights_to_concat, axis=0))
else:
loaded_weight = torch.from_numpy(
np.load(os.path.join(weights_path, name)))
for p in self._column_parallel_weights:
if p in name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
for p in self._row_parallel_weights:
if p in name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
assert param.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
@staticmethod
def get_weights(model_name: str, path: str):
path = os.path.join(path, f"{model_name}-np")
path = os.path.abspath(os.path.expanduser(path))
os.makedirs(path, exist_ok=True)
lock_path = os.path.join(path, "file_lock")
lock = filelock.FileLock(lock_path)
with lock:
test_weight_path = os.path.join(
path, "model.decoder.embed_positions.weight")
if os.path.exists(test_weight_path):
return path
folder = snapshot_download(model_name, allow_patterns="*.bin",
cache_dir=os.path.join(path, "cache"))
bin_files = glob.glob(os.path.join(folder, "*.bin"))
for bin_file in tqdm(bin_files, desc="Convert format"):
state = torch.load(bin_file, map_location="cpu")
for name, param in tqdm(state.items(), leave=False):
if name.startswith("decoder."):
name = "model." + name
param_path = os.path.join(path, name)
with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy())
return path
is_attention_weight = False
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id
:shard_size * (stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break
if is_attention_weight:
continue
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights)
def initialize_dummy_weights(self) -> None:
for param in self.state_dict().values():

View File

@ -1,6 +1,16 @@
from typing import Union
import os
import glob
import json
import filelock
from typing import Union, Optional
import numpy as np
import torch
from tqdm.auto import tqdm
from huggingface_hub import snapshot_download
from cacheflow.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank)
_STR_DTYPE_TO_TORCH_DTYPE = {
'half': torch.half,
@ -22,3 +32,86 @@ def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
torch_dtype = get_torch_dtype(dtype)
return torch.tensor([], dtype=torch_dtype).element_size()
class Disabledtqdm(tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, disable=True)
def hf_model_weights_iterator(model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
# Prepare file lock directory to prevent multiple processes from
# downloading the same model weights at the same time.
lock_dir = cache_dir if cache_dir is not None else "/tmp"
lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
# Download model weights from huggingface.
is_local = os.path.isdir(model_name_or_path)
if not is_local:
with lock:
hf_folder = snapshot_download(model_name_or_path,
allow_patterns="*.bin",
cache_dir=cache_dir,
tqdm_class=Disabledtqdm)
else:
hf_folder = model_name_or_path
hf_bin_files = glob.glob(os.path.join(hf_folder, "*.bin"))
if use_np_cache:
# Convert the model weights from torch tensors to numpy arrays for
# faster loading.
np_folder = os.path.join(hf_folder, 'np')
os.makedirs(np_folder, exist_ok=True)
weight_names_file = os.path.join(np_folder, 'weight_names.json')
with lock:
if not os.path.exists(weight_names_file):
weight_names = []
for bin_file in hf_bin_files:
state = torch.load(bin_file, map_location="cpu")
for name, param in state.items():
param_path = os.path.join(np_folder, name)
with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy())
weight_names.append(name)
with open(weight_names_file, 'w') as f:
json.dump(weight_names, f)
with open(weight_names_file, 'r') as f:
weight_names = json.load(f)
for name in weight_names:
param_path = os.path.join(np_folder, name)
with open(param_path, "rb") as f:
param = np.load(f)
yield name, torch.from_numpy(param)
else:
for bin_file in hf_bin_files:
state = torch.load(bin_file, map_location="cpu")
for name, param in state.items():
yield name, param
def load_tensor_parallel_weights(param, loaded_weight, param_name,
column_parallel_weight_names,
row_parallel_weight_names):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
for p in column_parallel_weight_names:
if p in param_name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
for p in row_parallel_weight_names:
if p in param_name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
assert param.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
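
hf_model_weights_iterator hides both paths behind one generator: with use_np_cache=False it yields tensors straight out of the downloaded *.bin shards, and with use_np_cache=True it first builds (under a file lock) and then reads the per-tensor numpy cache. A quick way to inspect what it yields; the model name is just an example:

from cacheflow.models.utils import hf_model_weights_iterator

for name, weight in hf_model_weights_iterator(
        'facebook/opt-125m', cache_dir=None, use_np_cache=False):
    print(name, tuple(weight.shape), weight.dtype)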

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Union, Tuple
from typing import Dict, List, Union, Tuple, Optional
try:
import ray
@ -29,8 +29,9 @@ class Controller:
num_cpu_blocks: int,
dtype: str,
seed: int,
model_path: str,
cache_dir: Optional[str],
use_dummy_weights: bool,
use_np_cache: bool,
max_num_batched_tokens: int,
use_ray: bool,
) -> None:
@ -66,8 +67,9 @@ class Controller:
world_size=world_size,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
model_path=model_path,
cache_dir=cache_dir,
use_dummy_weights=use_dummy_weights,
use_np_cache=use_np_cache,
max_num_batched_tokens=max_num_batched_tokens,
)
self.workers.append(worker)

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Optional
import torch
@ -28,8 +28,9 @@ class Worker:
distributed_init_method: str,
rank: int,
world_size: int,
model_path: str,
cache_dir: Optional[str],
use_dummy_weights: bool,
use_np_cache: bool,
max_num_batched_tokens: int,
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1,
@ -45,7 +46,8 @@ class Worker:
# Initialize the model.
self.model, self.dtype = get_model(
model_name, dtype=dtype, path=model_path, use_dummy_weights=use_dummy_weights)
model_name, dtype=dtype, cache_dir=cache_dir,
use_dummy_weights=use_dummy_weights, use_np_cache=use_np_cache)
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
initialize_all_reduce_launcher(

View File

@ -1,53 +1,13 @@
import argparse
from typing import List
from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
process_server_arguments,
initialize_cluster)
from cacheflow.master.server import (
add_server_arguments, process_server_arguments,
init_local_server_and_frontend_with_arguments)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory
def main(args: argparse.Namespace):
# TODO(zhuohan): Support pipeline parallelism.
assert args.pipeline_parallel_size == 1, (
'Pipeline parallelism is not supported yet.')
(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_cluster(
use_ray=args.use_ray,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))
# Create a server.
server = Server(
model=args.model,
model_path=args.model_path,
use_dummy_weights=args.use_dummy_weights,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
use_ray=args.use_ray,
)
# Create a frontend.
frontend = SimpleFrontend(
model_name=args.model,
block_size=args.block_size,
)
server, frontend = init_local_server_and_frontend_with_arguments(args)
# Test the following inputs.
test_inputs = [
('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),