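"""Simple CacheFlow server demo.

Brings up a Server on a Ray cluster, wires it to a SimpleFrontend, and
drives a few hard-coded test prompts to completion.
"""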
import argparse
from typing import List

from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
                                     initialize_ray_cluster)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory


def main(args: argparse.Namespace):
    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')

    # Initialize the Ray cluster and get the device placement for each
    # pipeline stage.
    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = (
        initialize_ray_cluster(
            pipeline_parallel_size=args.pipeline_parallel_size,
            tensor_parallel_size=args.tensor_parallel_size))

# Create a server.
    server = Server(
        model=args.model,
        model_path=args.model_path,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
        block_size=args.block_size,
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        max_batch_size=args.max_batch_size,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
        all_stage_devices=all_stage_devices,
        gpu_memory=get_gpu_memory(),
        cpu_memory=get_cpu_memory(),
    )

# Create a frontend.
    frontend = SimpleFrontend(
        model_name=args.model,
        block_size=args.block_size,
    )

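    # Note: SimpleFrontend buffers each submitted query; the loop below
    # drains them with frontend.get_inputs() and hands them to the server.
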
# Test the following inputs.
    test_inputs = [
        # Each entry is (prompt, dict of sampling parameter overrides).
        ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
        ('UC Berkeley is', {'n': 3, 'temperature': 0.8, 'top_p': 0.99}),
        ('The future of cloud computing is', {}),  # Use default parameters.
    ]
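    # Serving loop: each iteration feeds at most one pending test prompt to
    # the frontend, hands queued requests to the server, runs one
    # scheduling/execution step, and prints sequence groups as they finish.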
while True:
        if test_inputs:
            text, sampling_params_dict = test_inputs.pop(0)
            sampling_params = SamplingParams.from_dict(sampling_params_dict)
            sampling_params = frontend.add_eos_token(sampling_params)
            frontend.query(text, sampling_params)
        server.add_sequence_groups(frontend.get_inputs())
        updated_seq_groups = server.step()
        for seq_group in updated_seq_groups:
            if seq_group.is_finished():
                frontend.print_response(seq_group)
        if not (server.has_unfinished_requests() or test_inputs):
            break


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CacheFlow simple server.')
    parser = add_server_arguments(parser)
    args = parser.parse_args()
    main(args)
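
# Example invocation (a sketch only: the actual flags are registered by
# add_server_arguments, which is defined elsewhere, so every flag name and
# value below is an assumption inferred from the args.* attributes above):
#
#   python simple_server.py --model facebook/opt-13b \
#       --tensor-parallel-size 1 --block-size 16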