Add an option to launch cacheflow without ray (#51)
commit 4858f3bb45 (parent a96d63c21d)

.gitignore (vendored), 3 additions
@@ -3,8 +3,11 @@
 *.egg-info/
 *.eggs/
 *.so
+*.log
+*.csv
 build/
 
 *.pkl
 *.png
 **/log.txt
+.vscode/

@@ -8,7 +8,8 @@ import torch
 
 from cacheflow.master.simple_frontend import SimpleFrontend
 from cacheflow.master.server import (Server, add_server_arguments,
-                                     initialize_ray_cluster)
+                                     process_server_arguments,
+                                     initialize_cluster)
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.utils import get_gpu_memory, get_cpu_memory
 
@@ -20,8 +21,8 @@ def main(args: argparse.Namespace):
 
     (num_nodes, num_devices_per_node, distributed_init_method,
      all_stage_devices) = (
-        initialize_ray_cluster(
-            address='local',
+        initialize_cluster(
+            use_ray=args.use_ray,
             pipeline_parallel_size=args.pipeline_parallel_size,
             tensor_parallel_size=args.tensor_parallel_size))
 
@@ -44,6 +45,7 @@ def main(args: argparse.Namespace):
         all_stage_devices=all_stage_devices,
         gpu_memory=get_gpu_memory(),
         cpu_memory=get_cpu_memory(),
+        use_ray=args.use_ray,
     )
 
     # Create a frontend.
@@ -91,7 +93,8 @@ def main(args: argparse.Namespace):
 
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='CacheFlow simple server.')
+    parser = argparse.ArgumentParser(
+        description='Benchmark the latency of decoding a single sentence.')
     parser = add_server_arguments(parser)
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
@@ -99,6 +102,7 @@ if __name__ == '__main__':
     parser.add_argument('--n', type=int, default=1)
     parser.add_argument('--use-beam-search', action='store_true')
     args = parser.parse_args()
+    args = process_server_arguments(args)
     args.max_num_batched_tokens = max(
         args.max_num_batched_tokens, args.batch_size * args.input_len)
     print(args)

@@ -11,7 +11,8 @@ from transformers import AutoConfig
 from benchmark.trace import generate_text_completion_requests
 from cacheflow.master.simple_frontend import SimpleFrontend
 from cacheflow.master.server import (Server, add_server_arguments,
-                                     initialize_ray_cluster)
+                                     process_server_arguments,
+                                     initialize_cluster)
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.utils import get_gpu_memory, get_cpu_memory
 
@@ -25,8 +26,8 @@ def main(args: argparse.Namespace):
 
     (num_nodes, num_devices_per_node, distributed_init_method,
      all_stage_devices) = (
-        initialize_ray_cluster(
-            address='local',
+        initialize_cluster(
+            use_ray=args.use_ray,
             pipeline_parallel_size=args.pipeline_parallel_size,
             tensor_parallel_size=args.tensor_parallel_size))
 
@@ -49,6 +50,7 @@ def main(args: argparse.Namespace):
         all_stage_devices=all_stage_devices,
         gpu_memory=get_gpu_memory(),
         cpu_memory=get_cpu_memory(),
+        use_ray=args.use_ray,
         collect_stats=True,
         do_memory_analysis=args.do_memory_analysis,
     )
@@ -134,7 +136,7 @@ def main(args: argparse.Namespace):
             finished.append({
                 'group_id': seq_group.group_id,
                 'seq_id': seq.seq_id,
                 'arrival_time': arrival_time,
                 'finish_time': finish_time,
                 'prompt_len': seq.prompt_len,
                 'output_len': output_len,
@@ -225,8 +227,9 @@ def get_sampling_dir_name(
 
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='CacheFlow simple server.')
+    parser = argparse.ArgumentParser(
+        description='Benchmark the performance on a series of requests.')
     parser = add_server_arguments(parser)
     parser.add_argument('--output-dir', type=str, help='path to output directory', default=None)
 
     parser.add_argument('--dataset', type=str, help='path to dataset', required=True)
@@ -246,6 +249,7 @@ if __name__ == '__main__':
     parser.add_argument('--n6-beam', type=float, help='ratio of requests with n=6 & beam search', default=0.0)
    parser.add_argument('--n8-beam', type=float, help='ratio of requests with n=8 & beam search', default=0.0)
     args = parser.parse_args()
+    args = process_server_arguments(args)
     if args.n1 + args.n2 + args.n3 + args.n4 + args.n6 + args.n2_beam + args.n4_beam + args.n6_beam + args.n8_beam != 1.0:
         raise ValueError('The ratios of requests must sum to 1.')
 
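Both benchmark entry points above now share the same launch sequence. A condensed sketch of it (not part of the diff, written only with names that appear in the hunks above):

    import argparse
    from cacheflow.master.server import (add_server_arguments,
                                         process_server_arguments,
                                         initialize_cluster)

    parser = argparse.ArgumentParser()
    parser = add_server_arguments(parser)
    args = parser.parse_args()
    # Forces args.use_ray = True whenever pipeline_parallel_size * tensor_parallel_size > 1.
    args = process_server_arguments(args)

    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = initialize_cluster(
         use_ray=args.use_ray,
         pipeline_parallel_size=args.pipeline_parallel_size,
         tensor_parallel_size=args.tensor_parallel_size)

The resulting placement is then passed to Server(..., use_ray=args.use_ray), exactly as in the hunks above.
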
@@ -13,7 +13,8 @@ import uvicorn
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
 from cacheflow.master.server import (Server, add_server_arguments,
-                                     initialize_ray_cluster)
+                                     process_server_arguments,
+                                     initialize_cluster)
 from cacheflow.worker.controller import DeviceID
 from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
 
@@ -33,17 +34,22 @@ class FastAPIFrontend:
         seed: int,
         swap_space: int,
         max_num_batched_tokens: int,
+        max_num_sequences: int,
         num_nodes: int,
         num_devices_per_node: int,
         distributed_init_method: str,
         all_stage_devices: List[List[DeviceID]],
+        server_use_ray: bool,
     ):
         self.block_size = block_size
 
         self.tokenizer = AutoTokenizer.from_pretrained(model)
         self.seq_group_counter = Counter()
         self.seq_counter = Counter()
-        remote_server_class = ray.remote(num_cpus=0)(Server)
+        if server_use_ray:
+            remote_server_class = ray.remote(num_cpus=0)(Server)
+        else:
+            remote_server_class = ray.remote(num_gpus=1)(Server)
         self.server = remote_server_class.remote(
             model=model,
             model_path=model_path,
@@ -55,12 +61,14 @@ class FastAPIFrontend:
             seed=seed,
             swap_space=swap_space,
             max_num_batched_tokens=max_num_batched_tokens,
+            max_num_sequences=max_num_sequences,
             num_nodes=num_nodes,
             num_devices_per_node=num_devices_per_node,
             distributed_init_method=distributed_init_method,
             all_stage_devices=all_stage_devices,
             gpu_memory=get_gpu_memory(),
             cpu_memory=get_cpu_memory(),
+            use_ray=server_use_ray,
         )
 
         self.running_seq_groups: Dict[int, SequenceGroup] = {}
@@ -149,6 +157,7 @@ if __name__ == "__main__":
     parser.add_argument("--port", type=int, default=10002)
     parser = add_server_arguments(parser)
     args = parser.parse_args()
+    args = process_server_arguments(args)
 
     # TODO(zhuohan): Support pipeline parallelism.
     assert args.pipeline_parallel_size == 1, (
@@ -156,7 +165,8 @@ if __name__ == "__main__":
 
     (num_nodes, num_devices_per_node, distributed_init_method,
      all_stage_devices) = (
-        initialize_ray_cluster(
+        initialize_cluster(
+            use_ray=True,
             pipeline_parallel_size=args.pipeline_parallel_size,
             tensor_parallel_size=args.tensor_parallel_size))
 
@@ -170,10 +180,12 @@ if __name__ == "__main__":
         seed=args.seed,
         swap_space=args.swap_space,
         max_num_batched_tokens=args.max_num_batched_tokens,
+        max_num_sequences=args.max_num_sequences,
         num_nodes=num_nodes,
         num_devices_per_node=num_devices_per_node,
         distributed_init_method=distributed_init_method,
         all_stage_devices=all_stage_devices,
+        server_use_ray=args.use_ray,
     )
 
     uvicorn.run(app, host=args.host, port=args.port, log_level="info")

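Note that the HTTP frontend itself still requires Ray: it calls initialize_cluster(use_ray=True) unconditionally, and server_use_ray only selects how the Server actor is declared. A minimal sketch of that choice (the helper name is hypothetical, Server stands in for cacheflow's class, and the rationale comments are an inference from the controller hunks below, not stated in the diff):

    import ray

    class Server:  # stand-in for cacheflow.master.server.Server
        pass

    def make_server_actor_class(server_use_ray: bool):
        if server_use_ray:
            # Presumed rationale: workers are scheduled as their own GPU actors,
            # so the Server actor itself reserves no GPU.
            return ray.remote(num_cpus=0)(Server)
        # Single-GPU mode: the worker runs inside the Server actor's process,
        # so that actor claims the one GPU.
        return ray.remote(num_gpus=1)(Server)

    server_cls = make_server_actor_class(server_use_ray=False)
    # server = server_cls.remote(...)  # as in FastAPIFrontend.__init__ above
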
@@ -1,8 +1,12 @@
 import argparse
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 import random
 
-import ray
+import torch
+try:
+    import ray
+except ImportError:
+    ray = None
 
 from cacheflow.master.scheduler import Scheduler
 from cacheflow.models import get_memory_analyzer
@@ -31,6 +35,7 @@ class Server:
         all_stage_devices: List[List[DeviceID]],
         gpu_memory: int,
         cpu_memory: int,
+        use_ray: bool,
         collect_stats: bool = False,
         do_memory_analysis: bool = False,
     ):
@@ -38,6 +43,10 @@ class Server:
         self.num_devices_per_node = num_devices_per_node
         self.world_size = pipeline_parallel_size * tensor_parallel_size
 
+        if not use_ray:
+            assert self.world_size == 1, (
+                "Only support single GPU without Ray.")
+
         self.memory_analyzer = get_memory_analyzer(
             model_name=model,
             block_size=block_size,
@@ -72,6 +81,7 @@ class Server:
                 model_path=model_path,
                 use_dummy_weights=use_dummy_weights,
                 max_num_batched_tokens=max_num_batched_tokens,
+                use_ray=use_ray,
             )
             self.controllers.append(controller)
 
@@ -105,11 +115,30 @@ class Server:
             self.scheduler.swapped)
 
 
-def initialize_ray_cluster(
-    address: str = 'auto',
+def initialize_cluster(
+    use_ray: bool = False,
+    address: Optional[str] = None,
     pipeline_parallel_size: int = 1,
     tensor_parallel_size: int = 1,
 ) -> Tuple[int, int, str, List[List[DeviceID]]]:
+    # Initialize cluster locally.
+    if not use_ray:
+        assert pipeline_parallel_size * tensor_parallel_size == 1, (
+            "Only support single GPU without Ray.")
+        num_nodes = 1
+        num_devices_per_node = torch.cuda.device_count()
+        port = random.randint(10000, 20000)
+        # We need to setup the distributed init method to make sure
+        # the distributed megatron code (e.g., get world size) works correctly.
+        distributed_init_method = f"tcp://localhost:{port}"
+        all_stage_devices = [[(0, None, 0)]]
+        return (num_nodes, num_devices_per_node, distributed_init_method,
+                all_stage_devices)
+
+    assert ray is not None, (
+        "Ray is not installed. Please install Ray to use distributed "
+        "serving.")
+
     # Connect to a ray cluster.
     ray.init(address=address)
 
@@ -177,6 +206,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
                         help='model path to download and load the weights')
     # Parallel arguments
+    parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
     # KV cache arguments
@@ -190,3 +220,8 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration')
     parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
     return parser
+
+def process_server_arguments(args: argparse.Namespace):
+    if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
+        args.use_ray = True
+    return args

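A usage sketch (not from the diff) of the two paths through initialize_cluster above; the single-GPU return value follows directly from the added lines:

    from cacheflow.master.server import initialize_cluster

    # No Ray: returns (1, torch.cuda.device_count(),
    # "tcp://localhost:<random port>", [[(0, None, 0)]]) without touching ray.
    placement = initialize_cluster(use_ray=False)

    # With Ray: asserts that ray imported successfully, then calls
    # ray.init(address=address) and derives the placement from the cluster.
    placement = initialize_cluster(use_ray=True,
                                   pipeline_parallel_size=1,
                                   tensor_parallel_size=2)
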
@@ -1,6 +1,9 @@
 from typing import Dict, List, Union, Tuple
 
-import ray
+try:
+    import ray
+except ImportError:
+    ray = None
 
 from cacheflow.master.scheduler import Scheduler
 from cacheflow.sequence import SequenceGroupInputs
@@ -29,6 +32,7 @@ class Controller:
         model_path: str,
         use_dummy_weights: bool,
         max_num_batched_tokens: int,
+        use_ray: bool,
     ) -> None:
         self.stage_id = stage_id
         self.stage_devices = stage_devices
@@ -36,6 +40,7 @@ class Controller:
         self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
         self.num_cpu_blocks = num_cpu_blocks
+        self.use_ray = use_ray
 
         # Which pipeline stage is this node assigned to?
         self.is_first_stage = stage_id == 0
@@ -43,10 +48,13 @@ class Controller:
 
         self.workers: List[Worker] = []
         for rank, node_resource, device_id in stage_devices:
-            worker_cls = ray.remote(num_cpus=0,
-                                    num_gpus=1,
-                                    resources={node_resource: 1e-5})(Worker)
-            worker = worker_cls.remote(
+            if self.use_ray:
+                worker_cls = ray.remote(num_cpus=0,
+                                        num_gpus=1,
+                                        resources={node_resource: 1e-5})(Worker).remote
+            else:
+                worker_cls = Worker
+            worker = worker_cls(
                 model_name=model_name,
                 block_size=block_size,
                 num_gpu_blocks=num_gpu_blocks,
@@ -78,17 +86,21 @@ class Controller:
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
     ) -> None:
-        futures = []
+        all_outputs = []
         for worker in self.workers:
-            future = worker.execute_stage.remote(
+            executor = (worker.execute_stage.remote
+                        if self.use_ray else worker.execute_stage)
+            output = executor(
                 input_seq_groups,
                 blocks_to_swap_in,
                 blocks_to_swap_out,
                 blocks_to_copy,
             )
-            futures.append(future)
+            all_outputs.append(output)
 
-        all_outputs = ray.get(futures)
+        if self.use_ray:
+            all_outputs = ray.get(all_outputs)
+
         # Make sure all workers have the same results.
         output = all_outputs[0]
         for other_output in all_outputs[1:]:

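The controller changes reduce to one dispatch pattern: call either worker.execute_stage.remote (Ray) or worker.execute_stage (local), collect the results, and only block on ray.get when Ray is in use. A self-contained sketch of the pattern with a toy worker (not cacheflow's Worker):

    from typing import List

    try:
        import ray
    except ImportError:
        ray = None


    class ToyWorker:
        def execute_stage(self, x: int) -> int:
            return 2 * x


    def run_stage(workers: List[ToyWorker], use_ray: bool, x: int) -> List[int]:
        all_outputs = []
        for worker in workers:
            # With Ray the call returns a future; without it, the call runs inline.
            executor = (worker.execute_stage.remote
                        if use_ray else worker.execute_stage)
            all_outputs.append(executor(x))
        if use_ray:
            all_outputs = ray.get(all_outputs)  # resolve futures to values
        return all_outputs


    print(run_stage([ToyWorker(), ToyWorker()], use_ray=False, x=3))  # [6, 6]
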
@@ -3,7 +3,8 @@ from typing import List
 
 from cacheflow.master.simple_frontend import SimpleFrontend
 from cacheflow.master.server import (Server, add_server_arguments,
-                                     initialize_ray_cluster)
+                                     process_server_arguments,
+                                     initialize_cluster)
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.utils import get_gpu_memory, get_cpu_memory
 
@@ -14,7 +15,8 @@ def main(args: argparse.Namespace):
 
     (num_nodes, num_devices_per_node, distributed_init_method,
      all_stage_devices) = (
-        initialize_ray_cluster(
+        initialize_cluster(
+            use_ray=args.use_ray,
             pipeline_parallel_size=args.pipeline_parallel_size,
             tensor_parallel_size=args.tensor_parallel_size))
 
@@ -37,6 +39,7 @@ def main(args: argparse.Namespace):
         all_stage_devices=all_stage_devices,
         gpu_memory=get_gpu_memory(),
         cpu_memory=get_cpu_memory(),
+        use_ray=args.use_ray,
     )
 
     # Create a frontend.
@@ -70,4 +73,5 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='CacheFlow simple server.')
     parser = add_server_arguments(parser)
     args = parser.parse_args()
+    args = process_server_arguments(args)
     main(args)