import argparse
from typing import List

from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
                                     initialize_ray_cluster)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory


def main(args: argparse.Namespace):
    """Drive a CacheFlow server with a few hard-coded test prompts.

    Initializes the Ray cluster, builds a ``Server`` and a ``SimpleFrontend``
    from the parsed command-line options, then feeds one test prompt per
    iteration into the server. Every sequence group that reports itself as
    finished after a step is handed to ``frontend.print_response``. The loop
    ends once no test prompts remain and the server reports no unfinished
    requests.

    Args:
        args: Parsed command-line options. This function reads ``model``,
            ``model_path``, ``use_dummy_weights``, ``pipeline_parallel_size``,
            ``tensor_parallel_size``, ``block_size``, ``dtype``, ``seed``,
            ``swap_space`` and ``max_num_batched_tokens``.

    Raises:
        AssertionError: If ``args.pipeline_parallel_size`` is not 1.
    """
    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')

    cluster_info = initialize_ray_cluster(
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
    )
    (num_nodes, num_devices_per_node,
     distributed_init_method, all_stage_devices) = cluster_info

    # Create a server.
    server = Server(
        model=args.model,
        model_path=args.model_path,
        use_dummy_weights=args.use_dummy_weights,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
        block_size=args.block_size,
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        max_num_batched_tokens=args.max_num_batched_tokens,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
        all_stage_devices=all_stage_devices,
        gpu_memory=get_gpu_memory(),
        cpu_memory=get_cpu_memory(),
    )

    # Create a frontend.
    frontend = SimpleFrontend(
        model_name=args.model,
        block_size=args.block_size,
    )

    # Prompts to test, each paired with its sampling-parameter overrides.
    pending_prompts = [
        ('Ion Stoica is a',
         {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
        ('UC Berkeley is',
         {'n': 3, 'temperature': 0.8, 'top_p': 0.99}),
        ('The future of cloud computing is', {}),  # Use default parameters.
    ]

    while True:
        # Submit at most one new prompt per iteration, in list order.
        if pending_prompts:
            prompt, param_overrides = pending_prompts.pop(0)
            params = SamplingParams.from_dict(param_overrides)
            params = frontend.add_eos_token(params)
            frontend.query(prompt, params)

        # Forward whatever the frontend has queued, then advance the server.
        server.add_sequence_groups(frontend.get_inputs())
        for group in server.step():
            if group.is_finished():
                frontend.print_response(group)

        # Stop only when the server is idle and every prompt was submitted.
        if not server.has_unfinished_requests() and not pending_prompts:
            break


if __name__ == '__main__':
    parser = add_server_arguments(
        argparse.ArgumentParser(description='CacheFlow simple server.'))
    args = parser.parse_args()
    main(args)