New weight loader without np copy (#52)
parent 4858f3bb45
commit 27f1410d06
@@ -6,53 +6,15 @@ from tqdm import tqdm
 import numpy as np
 import torch
 
-from cacheflow.master.simple_frontend import SimpleFrontend
-from cacheflow.master.server import (Server, add_server_arguments,
-                                     process_server_arguments,
-                                     initialize_cluster)
+from cacheflow.master.server import (
+    add_server_arguments, process_server_arguments,
+    init_local_server_and_frontend_with_arguments)
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.utils import get_gpu_memory, get_cpu_memory
 
 
 def main(args: argparse.Namespace):
-    # TODO(zhuohan): Support pipeline parallelism.
-    assert args.pipeline_parallel_size == 1, (
-        'Pipeline parallelism is not supported yet.')
-
-    (num_nodes, num_devices_per_node, distributed_init_method,
-     all_stage_devices) = (
-        initialize_cluster(
-            use_ray=args.use_ray,
-            pipeline_parallel_size=args.pipeline_parallel_size,
-            tensor_parallel_size=args.tensor_parallel_size))
-
-    # Create a server.
-    server = Server(
-        model=args.model,
-        model_path=args.model_path,
-        use_dummy_weights=args.use_dummy_weights,
-        pipeline_parallel_size=args.pipeline_parallel_size,
-        tensor_parallel_size=args.tensor_parallel_size,
-        block_size=args.block_size,
-        dtype=args.dtype,
-        seed=args.seed,
-        swap_space=args.swap_space,
-        max_num_batched_tokens=args.max_num_batched_tokens,
-        max_num_sequences=args.max_num_sequences,
-        num_nodes=num_nodes,
-        num_devices_per_node=num_devices_per_node,
-        distributed_init_method=distributed_init_method,
-        all_stage_devices=all_stage_devices,
-        gpu_memory=get_gpu_memory(),
-        cpu_memory=get_cpu_memory(),
-        use_ray=args.use_ray,
-    )
-
-    # Create a frontend.
-    frontend = SimpleFrontend(
-        model_name=args.model,
-        block_size=args.block_size,
-    )
+    server, frontend = init_local_server_and_frontend_with_arguments(args)
 
     sampling_params_dict = {
         'n': args.n,
         'temperature': 0.0 if args.use_beam_search else 1.0,
@@ -9,57 +9,18 @@ from tqdm import tqdm
 from transformers import AutoConfig
 
 from benchmark.trace import generate_text_completion_requests
-from cacheflow.master.simple_frontend import SimpleFrontend
-from cacheflow.master.server import (Server, add_server_arguments,
-                                     process_server_arguments,
-                                     initialize_cluster)
+from cacheflow.master.server import (
+    add_server_arguments, process_server_arguments,
+    init_local_server_and_frontend_with_arguments)
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.utils import get_gpu_memory, get_cpu_memory
 
 
 logger = logging.getLogger(__name__)
 
 
 def main(args: argparse.Namespace):
-    assert args.pipeline_parallel_size == 1, (
-        'Pipeline parallelism is not supported yet.')
-
-    (num_nodes, num_devices_per_node, distributed_init_method,
-     all_stage_devices) = (
-        initialize_cluster(
-            use_ray=args.use_ray,
-            pipeline_parallel_size=args.pipeline_parallel_size,
-            tensor_parallel_size=args.tensor_parallel_size))
-
-    # Create a server.
-    server = Server(
-        model=args.model,
-        model_path=args.model_path,
-        use_dummy_weights=args.use_dummy_weights,
-        pipeline_parallel_size=args.pipeline_parallel_size,
-        tensor_parallel_size=args.tensor_parallel_size,
-        block_size=args.block_size,
-        dtype=args.dtype,
-        seed=args.seed,
-        swap_space=args.swap_space,
-        max_num_batched_tokens=args.max_num_batched_tokens,
-        max_num_sequences=args.max_num_sequences,
-        num_nodes=num_nodes,
-        num_devices_per_node=num_devices_per_node,
-        distributed_init_method=distributed_init_method,
-        all_stage_devices=all_stage_devices,
-        gpu_memory=get_gpu_memory(),
-        cpu_memory=get_cpu_memory(),
-        use_ray=args.use_ray,
-        collect_stats=True,
-        do_memory_analysis=args.do_memory_analysis,
-    )
-
-    # Create a frontend.
-    frontend = SimpleFrontend(
-        model_name=args.model,
-        block_size=args.block_size,
-    )
+    server, frontend = init_local_server_and_frontend_with_arguments(args)
 
     # Generate requests.
     requests = generate_text_completion_requests(
         args.dataset,
@@ -1,7 +1,7 @@
 import argparse
 import asyncio
 import time
-from typing import List, Dict
+from typing import List, Dict, Optional
 import json
 
 import ray
@@ -22,11 +22,12 @@ TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
 app = FastAPI()
 
 
-class FastAPIFrontend:
+class FastAPIServer:
     def __init__(
         self,
         model: str,
-        model_path: str,
+        cache_dir: Optional[str],
+        use_np_cache: bool,
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
         block_size: int,
@@ -52,8 +53,9 @@ class FastAPIFrontend:
         remote_server_class = ray.remote(num_gpus=1)(Server)
         self.server = remote_server_class.remote(
             model=model,
-            model_path=model_path,
+            cache_dir=cache_dir,
             use_dummy_weights=False,
+            use_np_cache=use_np_cache,
             pipeline_parallel_size=pipeline_parallel_size,
             tensor_parallel_size=tensor_parallel_size,
             block_size=block_size,
@@ -148,7 +150,7 @@ class FastAPIFrontend:
 @app.post("/generate")
 async def generate_stream(request: Request):
     request_dict = await request.json()
-    return StreamingResponse(frontend.generate(request_dict))
+    return StreamingResponse(server.generate(request_dict))
 
 
 if __name__ == "__main__":
@@ -170,9 +172,10 @@ if __name__ == "__main__":
             pipeline_parallel_size=args.pipeline_parallel_size,
             tensor_parallel_size=args.tensor_parallel_size))
 
-    frontend = FastAPIFrontend(
+    server = FastAPIServer(
         model=args.model,
-        model_path=args.model_path,
+        cache_dir=args.cache_dir,
+        use_np_cache=args.use_np_cache,
         pipeline_parallel_size=args.pipeline_parallel_size,
         tensor_parallel_size=args.tensor_parallel_size,
         block_size=args.block_size,
@@ -9,18 +9,21 @@ except ImportError:
     ray = None
 
 from cacheflow.master.scheduler import Scheduler
+from cacheflow.master.simple_frontend import SimpleFrontend
 from cacheflow.models import get_memory_analyzer
 from cacheflow.worker.controller import Controller, DeviceID
 from cacheflow.sequence import SequenceGroup
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.utils import get_gpu_memory, get_cpu_memory
 
 
 class Server:
     def __init__(
         self,
         model: str,
-        model_path: str,
+        cache_dir: Optional[str],
         use_dummy_weights: bool,
+        use_np_cache: bool,
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
         block_size: int,
@@ -78,8 +81,9 @@ class Server:
                 num_cpu_blocks=self.num_cpu_blocks,
                 dtype=dtype,
                 seed=seed,
-                model_path=model_path,
+                cache_dir=cache_dir,
                 use_dummy_weights=use_dummy_weights,
+                use_np_cache=use_np_cache,
                 max_num_batched_tokens=max_num_batched_tokens,
                 use_ray=use_ray,
             )
@@ -203,25 +207,72 @@ def initialize_cluster(
 def add_server_arguments(parser: argparse.ArgumentParser):
     # Model arguments
     parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
-    parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
-                        help='model path to download and load the weights')
+    parser.add_argument('--cache-dir', type=str, default=None,
+                        help='cache dir to download and load the weights, '
+                             'default to the default cache dir of huggingface')
+    parser.add_argument('--use-np-cache', action='store_true',
+                        help='save a numpy copy of model weights for faster loading')
+    parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
+    # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
+    parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
     # Parallel arguments
     parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
     # KV cache arguments
     parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
-    # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
-    parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
     parser.add_argument('--seed', type=int, default=0, help='random seed')
     parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
     parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
     parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration')
-    parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
     return parser
 
 
 def process_server_arguments(args: argparse.Namespace):
     if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
         args.use_ray = True
     return args
+
+
+def init_local_server_and_frontend_with_arguments(args: argparse.Namespace):
+    # TODO(zhuohan): Support pipeline parallelism.
+    assert args.pipeline_parallel_size == 1, (
+        'Pipeline parallelism is not supported yet.')
+
+    (num_nodes, num_devices_per_node, distributed_init_method,
+     all_stage_devices) = (
+        initialize_cluster(
+            use_ray=args.use_ray,
+            pipeline_parallel_size=args.pipeline_parallel_size,
+            tensor_parallel_size=args.tensor_parallel_size))
+
+    # Create a server.
+    server = Server(
+        model=args.model,
+        cache_dir=args.cache_dir,
+        use_dummy_weights=args.use_dummy_weights,
+        use_np_cache=args.use_np_cache,
+        pipeline_parallel_size=args.pipeline_parallel_size,
+        tensor_parallel_size=args.tensor_parallel_size,
+        block_size=args.block_size,
+        dtype=args.dtype,
+        seed=args.seed,
+        swap_space=args.swap_space,
+        max_num_batched_tokens=args.max_num_batched_tokens,
+        max_num_sequences=args.max_num_sequences,
+        num_nodes=num_nodes,
+        num_devices_per_node=num_devices_per_node,
+        distributed_init_method=distributed_init_method,
+        all_stage_devices=all_stage_devices,
+        gpu_memory=get_gpu_memory(),
+        cpu_memory=get_cpu_memory(),
+        use_ray=args.use_ray,
+    )
+
+    # Create a frontend.
+    frontend = SimpleFrontend(
+        model_name=args.model,
+        block_size=args.block_size,
+    )
+    return server, frontend
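Note: the new init_local_server_and_frontend_with_arguments() helper above absorbs the cluster/Server/SimpleFrontend boilerplate that each example script previously repeated. A minimal sketch of the intended call pattern, assuming only the flags registered by add_server_arguments (illustrative, not code from this commit):

    import argparse

    from cacheflow.master.server import (
        add_server_arguments, process_server_arguments,
        init_local_server_and_frontend_with_arguments)

    # Hypothetical driver script: parse the standard server flags, then build
    # the local Server and SimpleFrontend pair in one call.
    parser = argparse.ArgumentParser()
    parser = add_server_arguments(parser)
    args = process_server_arguments(parser.parse_args())
    server, frontend = init_local_server_and_frontend_with_arguments(args)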
@@ -1,18 +1,14 @@
 """1D GPT-NeoX model compatible with HuggingFace weights."""
-import os
-import glob
-import filelock
-from tqdm import tqdm
 from typing import Dict, List, Optional, Tuple
 
-import numpy as np
 import torch
 from torch import nn
-from huggingface_hub import snapshot_download
 
 from cacheflow.models import InputMetadata
 from cacheflow.models.attention import GPTNeoXCacheFlowAttention
 from cacheflow.models.sample import Sampler
+from cacheflow.models.utils import (hf_model_weights_iterator,
+                                    load_tensor_parallel_weights)
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
@@ -196,17 +192,22 @@ class GPTNeoXForCausalLM(nn.Module):
     _column_parallel_weights = ["embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
     _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
 
-    def load_weights(self, weights_path: str):
+    def load_weights(self, model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     use_np_cache: bool = False):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
-        for name, param in state_dict.items():
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, use_np_cache):
+            if ("attention.bias" in name or "attention.masked_bias" in name
+                or "rotary_emb.inv_freq" in name):
+                continue
+            param = state_dict[name]
             if "query_key_value" in name:
                 # NOTE(woosuk): GPT-NeoX's fused QKV has the shape of
                 # [num_heads * 3 * head_size, num_heads * head_size], while the
                 # required shape is [3 * num_heads * head_size, num_heads * head_size].
                 # Thus, we need weight conversion.
-                loaded_weight = torch.from_numpy(
-                    np.load(os.path.join(weights_path, name)))
                 shard_size = param.shape[0]
                 loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank
                                               :shard_size * (tensor_model_parallel_rank + 1)]
@@ -223,55 +224,10 @@ class GPTNeoXForCausalLM(nn.Module):
                     loaded_weight = loaded_weight.transpose(0, 1)
                     loaded_weight = loaded_weight.reshape(-1).contiguous()
                 else:
-                    assert False
-            else:
-                loaded_weight = torch.from_numpy(
-                    np.load(os.path.join(weights_path, name)))
-                for p in self._column_parallel_weights:
-                    if p in name:
-                        shard_size = param.shape[0]
-                        loaded_weight = loaded_weight[
-                            shard_size * tensor_model_parallel_rank
-                            :shard_size * (tensor_model_parallel_rank + 1)]
-                        break
-                for p in self._row_parallel_weights:
-                    if p in name:
-                        shard_size = param.shape[1]
-                        loaded_weight = loaded_weight[
-                            :,
-                            shard_size * tensor_model_parallel_rank
-                            :shard_size * (tensor_model_parallel_rank + 1)]
-                        break
-
-            assert param.shape == loaded_weight.shape
-            param.data.copy_(loaded_weight)
-
-    @staticmethod
-    def get_weights(model_name: str, path: str):
-        path = os.path.join(path, f"{model_name}-np")
-        path = os.path.abspath(os.path.expanduser(path))
-        os.makedirs(path, exist_ok=True)
-        lock_path = os.path.join(path, "file_lock")
-        lock = filelock.FileLock(lock_path)
-
-        with lock:
-            test_weight_path = os.path.join(
-                path, "gpt_neox.embed_in.weight")
-            if os.path.exists(test_weight_path):
-                return path
-
-            folder = snapshot_download(model_name, allow_patterns="*.bin",
-                                       cache_dir=os.path.join(path, "cache"))
-            bin_files = glob.glob(os.path.join(folder, "*.bin"))
-
-            for bin_file in tqdm(bin_files, desc="Convert format"):
-                state = torch.load(bin_file, map_location="cpu")
-                for name, param in tqdm(state.items(), leave=False):
-                    param_path = os.path.join(path, name)
-                    with open(param_path, "wb") as f:
-                        np.save(f, param.cpu().detach().numpy())
-
-            return path
+                    raise ValueError(f"Unexpected weight name: {name}")
+            load_tensor_parallel_weights(param, loaded_weight, name,
+                                         self._column_parallel_weights,
+                                         self._row_parallel_weights)
 
     def initialize_dummy_weights(self) -> None:
         for param in self.state_dict().values():
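Note: the fused-QKV comment in the hunk above ([num_heads * 3 * head_size, ...] vs. the required [3 * num_heads * head_size, ...]) is easiest to see with a tiny standalone sketch. The sizes below are made up for illustration only and are not part of this commit:

    import torch

    # Hypothetical sizes, chosen only to illustrate the layout conversion.
    num_heads, head_size, hidden_size = 4, 8, 32

    # HF GPT-NeoX stores the fused QKV weight with Q, K, V interleaved per head:
    # shape [num_heads * 3 * head_size, hidden_size].
    w = torch.randn(num_heads * 3 * head_size, hidden_size)

    # The attention kernel here expects all Q heads, then all K heads, then all
    # V heads: shape [3 * num_heads * head_size, hidden_size].
    w = w.view(num_heads, 3, head_size, hidden_size)
    w = w.transpose(0, 1)
    w = w.reshape(3 * num_heads * head_size, hidden_size).contiguous()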
@@ -1,11 +1,6 @@
 """1D LLaMA model compatible with HuggingFace weights."""
-import os
-import glob
-import filelock
-from tqdm import tqdm
 from typing import Dict, List, Optional, Tuple
 
-import numpy as np
 import torch
 from torch import nn
 from transformers import LlamaConfig
@@ -15,6 +10,8 @@ from cacheflow.models.activation import SiluAndMul
 from cacheflow.models.attention import LlamaCacheFlowAttention
 from cacheflow.models.layernorm import RMSNorm
 from cacheflow.models.sample import Sampler
+from cacheflow.models.utils import (hf_model_weights_iterator,
+                                    load_tensor_parallel_weights)
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
@@ -216,76 +213,57 @@ class LlamaForCausalLM(nn.Module):
                                 "up_proj.weight"]
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
 
-    def load_weights(self, weights_path: str):
+    def load_weights(self, model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     use_np_cache: bool = False):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
-        for name, param in state_dict.items():
-            if "qkv_proj" in name or "gate_up_proj" in name:
-                if "qkv_proj" in name:
-                    original_name = "qkv_proj"
-                    weight_names = ["q_proj", "k_proj", "v_proj"]
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, use_np_cache):
+            if "rotary_emb.inv_freq" in name:
+                continue
 
+            is_attention_weight = False
+            for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
+                if att_weight_name not in name:
+                    continue
+                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
                 shard_size = param.shape[0] // 3
-                else:
-                    original_name = "gate_up_proj"
-                    weight_names = ["gate_proj", "up_proj"]
+                loaded_weight = loaded_weight[
+                    shard_size * tensor_model_parallel_rank
+                    :shard_size * (tensor_model_parallel_rank + 1)]
+                param_slice = param.data[shard_size * stride_id
+                                         :shard_size * (stride_id + 1)]
+                assert param_slice.shape == loaded_weight.shape
+                param_slice.copy_(loaded_weight)
+                is_attention_weight = True
+                break
+            if is_attention_weight:
+                continue
+
+            is_gate_up_weight = False
+            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
+                if weight_name not in name:
+                    continue
+                param = state_dict[name.replace(weight_name, "gate_up_proj")]
                 shard_size = param.shape[0] // 2
-                weights_to_concat = []
-                for weight_name in weight_names:
-                    weight = np.load(os.path.join(
-                        weights_path, name.replace(original_name, weight_name)))
-                    weights_to_concat.append(weight[
-                        shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)])
-                loaded_weight = torch.from_numpy(
-                    np.concatenate(weights_to_concat, axis=0))
-            else:
-                loaded_weight = torch.from_numpy(
-                    np.load(os.path.join(weights_path, name)))
-                for p in self._column_parallel_weights:
-                    if p in name:
-                        shard_size = param.shape[0]
                 loaded_weight = loaded_weight[
                     shard_size * tensor_model_parallel_rank
                     :shard_size * (tensor_model_parallel_rank + 1)]
+                param_slice = param.data[shard_size * stride_id
+                                         :shard_size * (stride_id + 1)]
+                assert param_slice.shape == loaded_weight.shape
+                param_slice.copy_(loaded_weight)
+                is_gate_up_weight = True
                 break
-                for p in self._row_parallel_weights:
-                    if p in name:
-                        shard_size = param.shape[1]
-                        loaded_weight = loaded_weight[
-                            :,
-                            shard_size * tensor_model_parallel_rank
-                            :shard_size * (tensor_model_parallel_rank + 1)]
-                        break
+            if is_gate_up_weight:
+                continue
 
-            assert param.shape == loaded_weight.shape
-            param.data.copy_(loaded_weight)
-
-    @staticmethod
-    def get_weights(model_name: str, path: str):
-        if not os.path.isfile(os.path.join(model_name, "config.json")):
-            raise ValueError("LLaMA model's model_name has to be a path"
-                             "to the huggingface model's directory.")
-        path = os.path.join(model_name, f"np")
-        path = os.path.abspath(os.path.expanduser(path))
-        os.makedirs(path, exist_ok=True)
-        lock_path = os.path.join(path, "file_lock")
-        lock = filelock.FileLock(lock_path)
-
-        with lock:
-            test_weight_path = os.path.join(path, "model.embed_tokens.weight")
-            if os.path.exists(test_weight_path):
-                return path
-
-            bin_files = glob.glob(os.path.join(model_name, "*.bin"))
-
-            for bin_file in tqdm(bin_files, desc="Convert format"):
-                state = torch.load(bin_file, map_location="cpu")
-                for name, param in tqdm(state.items(), leave=False):
-                    param_path = os.path.join(path, name)
-                    with open(param_path, "wb") as f:
-                        np.save(f, param.cpu().detach().numpy())
-
-            return path
+            param = state_dict[name]
+            load_tensor_parallel_weights(param, loaded_weight, name,
+                                         self._column_parallel_weights,
+                                         self._row_parallel_weights)
 
     def initialize_dummy_weights(self) -> None:
         for param in self.state_dict().values():
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Union, Optional
 
 import torch
 import torch.nn as nn
@@ -32,8 +32,9 @@ _MEMORY_ANALYZERS = {
 def get_model(
     model_name: str,
     dtype: Union[torch.dtype, str],
-    path: str,
+    cache_dir: Optional[str],
     use_dummy_weights: bool,
+    use_np_cache: bool,
 ) -> nn.Module:
     torch_dtype = get_torch_dtype(dtype)
     torch.set_default_dtype(torch_dtype)
@@ -49,12 +50,10 @@ def get_model(
             # random values to the weights.
             model.initialize_dummy_weights()
         else:
-            # Download model weights if it's not cached.
-            weights_dir = model_class.get_weights(model_name, path=path)
             # Create a model instance.
             model = model_class(config)
             # Load the weights from the cached or downloaded files.
-            model.load_weights(weights_dir)
+            model.load_weights(model_name, cache_dir, use_np_cache)
         model = model.cuda()
         return model.eval(), torch_dtype
     raise ValueError(f'Unsupported model name: {model_name}')
@@ -1,19 +1,15 @@
 """1D OPT model compatible with HuggingFace weights."""
-import os
-import glob
-import filelock
-from tqdm import tqdm
 from typing import Dict, List, Optional, Tuple
 
-import numpy as np
 import torch
 from torch import nn
 from transformers import OPTConfig
-from huggingface_hub import snapshot_download
 
 from cacheflow.models import InputMetadata
 from cacheflow.models.attention import OPTCacheFlowAttention
 from cacheflow.models.sample import Sampler
+from cacheflow.models.utils import (hf_model_weights_iterator,
+                                    load_tensor_parallel_weights)
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
@@ -257,73 +253,42 @@ class OPTForCausalLM(nn.Module):
     _column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
     _row_parallel_weights = ["out_proj.weight", "fc2.weight"]
 
-    def load_weights(self, weights_path: str):
+    def load_weights(self, model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     use_np_cache: bool = False):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
-        for name, param in state_dict.items():
-            if "lm_head_weight" in name:
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, use_np_cache):
+            if "lm_head.weight" in name:
                 continue
-            if "qkv_proj" in name:
-                shard_size = param.shape[0] // 3
-                weights_to_concat = []
-                for weight_name in ["q_proj", "k_proj", "v_proj"]:
-                    weight = np.load(os.path.join(
-                        weights_path, name.replace("qkv_proj", weight_name)))
-                    weights_to_concat.append(weight[
-                        shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)])
-                loaded_weight = torch.from_numpy(
-                    np.concatenate(weights_to_concat, axis=0))
-            else:
-                loaded_weight = torch.from_numpy(
-                    np.load(os.path.join(weights_path, name)))
-                for p in self._column_parallel_weights:
-                    if p in name:
-                        shard_size = param.shape[0]
-                        loaded_weight = loaded_weight[
-                            shard_size * tensor_model_parallel_rank
-                            :shard_size * (tensor_model_parallel_rank + 1)]
-                        break
-                for p in self._row_parallel_weights:
-                    if p in name:
-                        shard_size = param.shape[1]
-                        loaded_weight = loaded_weight[
-                            :,
-                            shard_size * tensor_model_parallel_rank
-                            :shard_size * (tensor_model_parallel_rank + 1)]
-                        break
-
-            assert param.shape == loaded_weight.shape
-            param.data.copy_(loaded_weight)
-
-    @staticmethod
-    def get_weights(model_name: str, path: str):
-        path = os.path.join(path, f"{model_name}-np")
-        path = os.path.abspath(os.path.expanduser(path))
-        os.makedirs(path, exist_ok=True)
-        lock_path = os.path.join(path, "file_lock")
-        lock = filelock.FileLock(lock_path)
-
-        with lock:
-            test_weight_path = os.path.join(
-                path, "model.decoder.embed_positions.weight")
-            if os.path.exists(test_weight_path):
-                return path
-
-            folder = snapshot_download(model_name, allow_patterns="*.bin",
-                                       cache_dir=os.path.join(path, "cache"))
-            bin_files = glob.glob(os.path.join(folder, "*.bin"))
-
-            for bin_file in tqdm(bin_files, desc="Convert format"):
-                state = torch.load(bin_file, map_location="cpu")
-                for name, param in tqdm(state.items(), leave=False):
-                    if name.startswith("decoder."):
-                        name = "model." + name
-                    param_path = os.path.join(path, name)
-                    with open(param_path, "wb") as f:
-                        np.save(f, param.cpu().detach().numpy())
-
-            return path
+            if name.startswith("decoder."):
+                name = "model." + name
+
+            is_attention_weight = False
+            for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
+                if att_weight_name not in name:
+                    continue
+                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
+                shard_size = param.shape[0] // 3
+                loaded_weight = loaded_weight[
+                    shard_size * tensor_model_parallel_rank
+                    :shard_size * (tensor_model_parallel_rank + 1)]
+                param_slice = param.data[shard_size * stride_id
+                                         :shard_size * (stride_id + 1)]
+                assert param_slice.shape == loaded_weight.shape
+                param_slice.copy_(loaded_weight)
+                is_attention_weight = True
+                break
+            if is_attention_weight:
+                continue
+
+            param = state_dict[name]
+            load_tensor_parallel_weights(param, loaded_weight, name,
+                                         self._column_parallel_weights,
+                                         self._row_parallel_weights)
 
     def initialize_dummy_weights(self) -> None:
         for param in self.state_dict().values():
@@ -1,6 +1,16 @@
-from typing import Union
+import os
+import glob
+import json
+import filelock
+from typing import Union, Optional
 
+import numpy as np
 import torch
+from tqdm.auto import tqdm
+from huggingface_hub import snapshot_download
+from cacheflow.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_rank)
 
 
 _STR_DTYPE_TO_TORCH_DTYPE = {
     'half': torch.half,
@@ -22,3 +32,86 @@ def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
     torch_dtype = get_torch_dtype(dtype)
     return torch.tensor([], dtype=torch_dtype).element_size()
+
+
+class Disabledtqdm(tqdm):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs, disable=True)
+
+
+def hf_model_weights_iterator(model_name_or_path: str,
+                              cache_dir: Optional[str] = None,
+                              use_np_cache: bool = False):
+    # Prepare file lock directory to prevent multiple processes from
+    # downloading the same model weights at the same time.
+    lock_dir = cache_dir if cache_dir is not None else "/tmp"
+    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
+    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
+
+    # Download model weights from huggingface.
+    is_local = os.path.isdir(model_name_or_path)
+    if not is_local:
+        with lock:
+            hf_folder = snapshot_download(model_name_or_path,
+                                          allow_patterns="*.bin",
+                                          cache_dir=cache_dir,
+                                          tqdm_class=Disabledtqdm)
+    else:
+        hf_folder = model_name_or_path
+
+    hf_bin_files = glob.glob(os.path.join(hf_folder, "*.bin"))
+
+    if use_np_cache:
+        # Convert the model weights from torch tensors to numpy arrays for
+        # faster loading.
+        np_folder = os.path.join(hf_folder, 'np')
+        os.makedirs(np_folder, exist_ok=True)
+        weight_names_file = os.path.join(np_folder, 'weight_names.json')
+        with lock:
+            if not os.path.exists(weight_names_file):
+                weight_names = []
+                for bin_file in hf_bin_files:
+                    state = torch.load(bin_file, map_location="cpu")
+                    for name, param in state.items():
+                        param_path = os.path.join(np_folder, name)
+                        with open(param_path, "wb") as f:
+                            np.save(f, param.cpu().detach().numpy())
+                        weight_names.append(name)
+                with open(weight_names_file, 'w') as f:
+                    json.dump(weight_names, f)
+
+        with open(weight_names_file, 'r') as f:
+            weight_names = json.load(f)
+
+        for name in weight_names:
+            param_path = os.path.join(np_folder, name)
+            with open(param_path, "rb") as f:
+                param = np.load(f)
+            yield name, torch.from_numpy(param)
+    else:
+        for bin_file in hf_bin_files:
+            state = torch.load(bin_file, map_location="cpu")
+            for name, param in state.items():
+                yield name, param
+
+
+def load_tensor_parallel_weights(param, loaded_weight, param_name,
+                                 column_parallel_weight_names,
+                                 row_parallel_weight_names):
+    tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+    for p in column_parallel_weight_names:
+        if p in param_name:
+            shard_size = param.shape[0]
+            loaded_weight = loaded_weight[
+                shard_size * tensor_model_parallel_rank
+                :shard_size * (tensor_model_parallel_rank + 1)]
+            break
+    for p in row_parallel_weight_names:
+        if p in param_name:
+            shard_size = param.shape[1]
+            loaded_weight = loaded_weight[
+                :,
+                shard_size * tensor_model_parallel_rank
+                :shard_size * (tensor_model_parallel_rank + 1)]
+            break
+    assert param.shape == loaded_weight.shape
+    param.data.copy_(loaded_weight)
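Note: hf_model_weights_iterator() above replaces the old per-model get_weights() + np.load path; it yields (name, tensor) pairs straight from the downloaded *.bin checkpoints, or from a one-time per-tensor numpy cache when use_np_cache is set. A rough sketch of how a model's load_weights consumes it, mirroring the OPT/LLaMA/GPT-NeoX methods in this commit (the wrapper function below is illustrative, not code from the commit):

    from cacheflow.models.utils import (hf_model_weights_iterator,
                                        load_tensor_parallel_weights)

    # Hypothetical consumer of the new loader (sketch only).
    def load_weights(model, model_name_or_path, cache_dir=None, use_np_cache=False):
        state_dict = model.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            param = state_dict[name]
            # Shards column-/row-parallel tensors for this tensor-parallel rank,
            # asserts matching shapes, and copies into the parameter in place.
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         model._column_parallel_weights,
                                         model._row_parallel_weights)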
@@ -1,4 +1,4 @@
-from typing import Dict, List, Union, Tuple
+from typing import Dict, List, Union, Tuple, Optional
 
 try:
     import ray
@@ -29,8 +29,9 @@ class Controller:
         num_cpu_blocks: int,
         dtype: str,
         seed: int,
-        model_path: str,
+        cache_dir: Optional[str],
         use_dummy_weights: bool,
+        use_np_cache: bool,
         max_num_batched_tokens: int,
         use_ray: bool,
     ) -> None:
@@ -66,8 +67,9 @@ class Controller:
                 world_size=world_size,
                 tensor_parallel_size=tensor_parallel_size,
                 pipeline_parallel_size=pipeline_parallel_size,
-                model_path=model_path,
+                cache_dir=cache_dir,
                 use_dummy_weights=use_dummy_weights,
+                use_np_cache=use_np_cache,
                 max_num_batched_tokens=max_num_batched_tokens,
             )
             self.workers.append(worker)
@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional
 
 import torch
 
@@ -28,8 +28,9 @@ class Worker:
         distributed_init_method: str,
         rank: int,
         world_size: int,
-        model_path: str,
+        cache_dir: Optional[str],
         use_dummy_weights: bool,
+        use_np_cache: bool,
         max_num_batched_tokens: int,
         tensor_parallel_size: int = 1,
         pipeline_parallel_size: int = 1,
@@ -45,7 +46,8 @@ class Worker:
 
         # Initialize the model.
         self.model, self.dtype = get_model(
-            model_name, dtype=dtype, path=model_path, use_dummy_weights=use_dummy_weights)
+            model_name, dtype=dtype, cache_dir=cache_dir,
+            use_dummy_weights=use_dummy_weights, use_np_cache=use_np_cache)
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         initialize_all_reduce_launcher(
@@ -1,53 +1,13 @@
 import argparse
 from typing import List
 
-from cacheflow.master.simple_frontend import SimpleFrontend
-from cacheflow.master.server import (Server, add_server_arguments,
-                                     process_server_arguments,
-                                     initialize_cluster)
+from cacheflow.master.server import (
+    add_server_arguments, process_server_arguments,
+    init_local_server_and_frontend_with_arguments)
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.utils import get_gpu_memory, get_cpu_memory
 
 
 def main(args: argparse.Namespace):
-    # TODO(zhuohan): Support pipeline parallelism.
-    assert args.pipeline_parallel_size == 1, (
-        'Pipeline parallelism is not supported yet.')
-
-    (num_nodes, num_devices_per_node, distributed_init_method,
-     all_stage_devices) = (
-        initialize_cluster(
-            use_ray=args.use_ray,
-            pipeline_parallel_size=args.pipeline_parallel_size,
-            tensor_parallel_size=args.tensor_parallel_size))
-
-    # Create a server.
-    server = Server(
-        model=args.model,
-        model_path=args.model_path,
-        use_dummy_weights=args.use_dummy_weights,
-        pipeline_parallel_size=args.pipeline_parallel_size,
-        tensor_parallel_size=args.tensor_parallel_size,
-        block_size=args.block_size,
-        dtype=args.dtype,
-        seed=args.seed,
-        swap_space=args.swap_space,
-        max_num_batched_tokens=args.max_num_batched_tokens,
-        max_num_sequences=args.max_num_sequences,
-        num_nodes=num_nodes,
-        num_devices_per_node=num_devices_per_node,
-        distributed_init_method=distributed_init_method,
-        all_stage_devices=all_stage_devices,
-        gpu_memory=get_gpu_memory(),
-        cpu_memory=get_cpu_memory(),
-        use_ray=args.use_ray,
-    )
-
-    # Create a frontend.
-    frontend = SimpleFrontend(
-        model_name=args.model,
-        block_size=args.block_size,
-    )
+    server, frontend = init_local_server_and_frontend_with_arguments(args)
 
     # Test the following inputs.
     test_inputs = [
         ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),