Add an option to launch cacheflow without ray (#51)

This commit is contained in:
Zhuohan Li 2023-04-30 15:42:17 +08:00 committed by GitHub
parent a96d63c21d
commit 4858f3bb45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 102 additions and 28 deletions

3
.gitignore vendored
View File

@ -3,8 +3,11 @@
*.egg-info/
*.eggs/
*.so
*.log
*.csv
build/
*.pkl
*.png
**/log.txt
.vscode/

View File

@ -8,7 +8,8 @@ import torch
from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
initialize_ray_cluster)
process_server_arguments,
initialize_cluster)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory
@ -20,8 +21,8 @@ def main(args: argparse.Namespace):
(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_ray_cluster(
address='local',
initialize_cluster(
use_ray=args.use_ray,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))
@ -44,6 +45,7 @@ def main(args: argparse.Namespace):
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
use_ray=args.use_ray,
)
# Create a frontend.
@ -91,7 +93,8 @@ def main(args: argparse.Namespace):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='CacheFlow simple server.')
parser = argparse.ArgumentParser(
description='Benchmark the latency of decoding a single sentence.')
parser = add_server_arguments(parser)
parser.add_argument('--input-len', type=int, default=32)
parser.add_argument('--output-len', type=int, default=128)
@ -99,6 +102,7 @@ if __name__ == '__main__':
parser.add_argument('--n', type=int, default=1)
parser.add_argument('--use-beam-search', action='store_true')
args = parser.parse_args()
args = process_server_arguments(args)
args.max_num_batched_tokens = max(
args.max_num_batched_tokens, args.batch_size * args.input_len)
print(args)

View File

@ -11,7 +11,8 @@ from transformers import AutoConfig
from benchmark.trace import generate_text_completion_requests
from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
initialize_ray_cluster)
process_server_arguments,
initialize_cluster)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory
@ -25,8 +26,8 @@ def main(args: argparse.Namespace):
(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_ray_cluster(
address='local',
initialize_cluster(
use_ray=args.use_ray,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))
@ -49,6 +50,7 @@ def main(args: argparse.Namespace):
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
use_ray=args.use_ray,
collect_stats=True,
do_memory_analysis=args.do_memory_analysis,
)
@ -134,7 +136,7 @@ def main(args: argparse.Namespace):
finished.append({
'group_id': seq_group.group_id,
'seq_id': seq.seq_id,
'arrival_time': arrival_time,
'arrival_time': arrival_time,
'finish_time': finish_time,
'prompt_len': seq.prompt_len,
'output_len': output_len,
@ -225,8 +227,9 @@ def get_sampling_dir_name(
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='CacheFlow simple server.')
parser = add_server_arguments(parser)
parser = argparse.ArgumentParser(
description='Benchmark the performance on a series of requests.')
parser = add_server_arguments(parser)
parser.add_argument('--output-dir', type=str, help='path to output directory', default=None)
parser.add_argument('--dataset', type=str, help='path to dataset', required=True)
@ -246,6 +249,7 @@ if __name__ == '__main__':
parser.add_argument('--n6-beam', type=float, help='ratio of requests with n=6 & beam search', default=0.0)
parser.add_argument('--n8-beam', type=float, help='ratio of requests with n=8 & beam search', default=0.0)
args = parser.parse_args()
args = process_server_arguments(args)
if args.n1 + args.n2 + args.n3 + args.n4 + args.n6 + args.n2_beam + args.n4_beam + args.n6_beam + args.n8_beam != 1.0:
raise ValueError('The ratios of requests must sum to 1.')

View File

@ -13,7 +13,8 @@ import uvicorn
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.master.server import (Server, add_server_arguments,
initialize_ray_cluster)
process_server_arguments,
initialize_cluster)
from cacheflow.worker.controller import DeviceID
from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
@ -33,17 +34,22 @@ class FastAPIFrontend:
seed: int,
swap_space: int,
max_num_batched_tokens: int,
max_num_sequences: int,
num_nodes: int,
num_devices_per_node: int,
distributed_init_method: str,
all_stage_devices: List[List[DeviceID]],
server_use_ray: bool,
):
self.block_size = block_size
self.tokenizer = AutoTokenizer.from_pretrained(model)
self.seq_group_counter = Counter()
self.seq_counter = Counter()
remote_server_class = ray.remote(num_cpus=0)(Server)
if server_use_ray:
remote_server_class = ray.remote(num_cpus=0)(Server)
else:
remote_server_class = ray.remote(num_gpus=1)(Server)
self.server = remote_server_class.remote(
model=model,
model_path=model_path,
@ -55,12 +61,14 @@ class FastAPIFrontend:
seed=seed,
swap_space=swap_space,
max_num_batched_tokens=max_num_batched_tokens,
max_num_sequences=max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
use_ray=server_use_ray,
)
self.running_seq_groups: Dict[int, SequenceGroup] = {}
@ -149,6 +157,7 @@ if __name__ == "__main__":
parser.add_argument("--port", type=int, default=10002)
parser = add_server_arguments(parser)
args = parser.parse_args()
args = process_server_arguments(args)
# TODO(zhuohan): Support pipeline parallelism.
assert args.pipeline_parallel_size == 1, (
@ -156,7 +165,8 @@ if __name__ == "__main__":
(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_ray_cluster(
initialize_cluster(
use_ray=True,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))
@ -170,10 +180,12 @@ if __name__ == "__main__":
seed=args.seed,
swap_space=args.swap_space,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
server_use_ray=args.use_ray,
)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")

View File

@ -1,8 +1,12 @@
import argparse
from typing import List, Tuple
from typing import List, Tuple, Optional
import random
import ray
import torch
try:
import ray
except ImportError:
ray = None
from cacheflow.master.scheduler import Scheduler
from cacheflow.models import get_memory_analyzer
@ -31,6 +35,7 @@ class Server:
all_stage_devices: List[List[DeviceID]],
gpu_memory: int,
cpu_memory: int,
use_ray: bool,
collect_stats: bool = False,
do_memory_analysis: bool = False,
):
@ -38,6 +43,10 @@ class Server:
self.num_devices_per_node = num_devices_per_node
self.world_size = pipeline_parallel_size * tensor_parallel_size
if not use_ray:
assert self.world_size == 1, (
"Only support single GPU without Ray.")
self.memory_analyzer = get_memory_analyzer(
model_name=model,
block_size=block_size,
@ -72,6 +81,7 @@ class Server:
model_path=model_path,
use_dummy_weights=use_dummy_weights,
max_num_batched_tokens=max_num_batched_tokens,
use_ray=use_ray,
)
self.controllers.append(controller)
@ -105,11 +115,30 @@ class Server:
self.scheduler.swapped)
def initialize_ray_cluster(
address: str = 'auto',
def initialize_cluster(
use_ray: bool = False,
address: Optional[str] = None,
pipeline_parallel_size: int = 1,
tensor_parallel_size: int = 1,
) -> Tuple[int, int, str, List[List[DeviceID]]]:
# Initialize cluster locally.
if not use_ray:
assert pipeline_parallel_size * tensor_parallel_size == 1, (
"Only support single GPU without Ray.")
num_nodes = 1
num_devices_per_node = torch.cuda.device_count()
port = random.randint(10000, 20000)
# We need to setup the distributed init method to make sure
# the distributed megatron code (e.g., get world size) works correctly.
distributed_init_method = f"tcp://localhost:{port}"
all_stage_devices = [[(0, None, 0)]]
return (num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices)
assert ray is not None, (
"Ray is not installed. Please install Ray to use distributed "
"serving.")
# Connect to a ray cluster.
ray.init(address=address)
@ -177,6 +206,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
help='model path to download and load the weights')
# Parallel arguments
parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
# KV cache arguments
@ -190,3 +220,8 @@ def add_server_arguments(parser: argparse.ArgumentParser):
parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
return parser
def process_server_arguments(args: argparse.Namespace):
if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
args.use_ray = True
return args

View File

@ -1,6 +1,9 @@
from typing import Dict, List, Union, Tuple
import ray
try:
import ray
except ImportError:
ray = None
from cacheflow.master.scheduler import Scheduler
from cacheflow.sequence import SequenceGroupInputs
@ -29,6 +32,7 @@ class Controller:
model_path: str,
use_dummy_weights: bool,
max_num_batched_tokens: int,
use_ray: bool,
) -> None:
self.stage_id = stage_id
self.stage_devices = stage_devices
@ -36,6 +40,7 @@ class Controller:
self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks
self.num_cpu_blocks = num_cpu_blocks
self.use_ray = use_ray
# Which pipeline stage is this node assigned to?
self.is_first_stage = stage_id == 0
@ -43,10 +48,13 @@ class Controller:
self.workers: List[Worker] = []
for rank, node_resource, device_id in stage_devices:
worker_cls = ray.remote(num_cpus=0,
num_gpus=1,
resources={node_resource: 1e-5})(Worker)
worker = worker_cls.remote(
if self.use_ray:
worker_cls = ray.remote(num_cpus=0,
num_gpus=1,
resources={node_resource: 1e-5})(Worker).remote
else:
worker_cls = Worker
worker = worker_cls(
model_name=model_name,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
@ -78,17 +86,21 @@ class Controller:
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None:
futures = []
all_outputs = []
for worker in self.workers:
future = worker.execute_stage.remote(
executor = (worker.execute_stage.remote
if self.use_ray else worker.execute_stage)
output = executor(
input_seq_groups,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
)
futures.append(future)
all_outputs.append(output)
if self.use_ray:
all_outputs = ray.get(all_outputs)
all_outputs = ray.get(futures)
# Make sure all workers have the same results.
output = all_outputs[0]
for other_output in all_outputs[1:]:

View File

@ -3,7 +3,8 @@ from typing import List
from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
initialize_ray_cluster)
process_server_arguments,
initialize_cluster)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory
@ -14,7 +15,8 @@ def main(args: argparse.Namespace):
(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_ray_cluster(
initialize_cluster(
use_ray=args.use_ray,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))
@ -37,6 +39,7 @@ def main(args: argparse.Namespace):
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
use_ray=args.use_ray,
)
# Create a frontend.
@ -70,4 +73,5 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser(description='CacheFlow simple server.')
parser = add_server_arguments(parser)
args = parser.parse_args()
args = process_server_arguments(args)
main(args)