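"""Simple CacheFlow server.

Runs a few hard-coded test prompts through a CacheFlow Server and prints the
generated completions. Example invocation (the exact flags are defined by
add_server_arguments; --model is assumed here for illustration):

    python simple_server.py --model <model-name>
"""
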
import argparse

from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
                                     process_server_arguments,
                                     initialize_cluster)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory


def main(args: argparse.Namespace):
    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')

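    # Set up the (optionally Ray-managed) cluster and get the device
    # placement for every pipeline stage.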
    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = (
        initialize_cluster(
            use_ray=args.use_ray,
            pipeline_parallel_size=args.pipeline_parallel_size,
            tensor_parallel_size=args.tensor_parallel_size))

    # Create a server.
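    # block_size and swap_space size the paged KV cache; the measured GPU/CPU
    # memory bounds how many cache blocks the server may allocate.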
    server = Server(
        model=args.model,
        model_path=args.model_path,
        use_dummy_weights=args.use_dummy_weights,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
        block_size=args.block_size,
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        max_num_batched_tokens=args.max_num_batched_tokens,
        max_num_sequences=args.max_num_sequences,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
        all_stage_devices=all_stage_devices,
        gpu_memory=get_gpu_memory(),
        cpu_memory=get_cpu_memory(),
        use_ray=args.use_ray,
    )

    # Create a frontend.
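    # The frontend buffers incoming requests until the server drains them
    # through get_inputs() below.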
    frontend = SimpleFrontend(
        model_name=args.model,
        block_size=args.block_size,
    )

    # Test inputs: (prompt, SamplingParams overrides) pairs.
    test_inputs = [
        ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
        ('UC Berkeley is', {'n': 3, 'temperature': 0.8, 'top_p': 0.99}),
        ('The future of cloud computing is', {}),  # Use default parameters.
    ]
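    # Event loop: submit at most one pending prompt per iteration, advance
    # the server by one step, and print any requests that finished.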
    while True:
        if test_inputs:
            text, sampling_params_dict = test_inputs.pop(0)
            sampling_params = SamplingParams.from_dict(sampling_params_dict)
            sampling_params = frontend.add_eos_token(sampling_params)
            frontend.query(text, sampling_params)
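        # Hand newly queued requests to the server and run one decoding step.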
        server.add_sequence_groups(frontend.get_inputs())
        updated_seq_groups = server.step()
        for seq_group in updated_seq_groups:
            if seq_group.is_finished():
                frontend.print_response(seq_group)
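        # Stop once all requests have finished and no test inputs remain.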
        if not (server.has_unfinished_requests() or test_inputs):
            break


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CacheFlow simple server.')
    parser = add_server_arguments(parser)
    args = parser.parse_args()
    args = process_server_arguments(args)
    main(args)