vllm/simple_server.py

import argparse

from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
                                     initialize_ray_cluster)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory
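

# Bring up a CacheFlow server, push a few test prompts through a
# SimpleFrontend, and print each response as it completes.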
def main(args: argparse.Namespace):
    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')
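
    # Set up the Ray cluster: returns the cluster shape, the address used to
    # initialize the distributed workers, and the device placement for each
    # pipeline stage.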
    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = (
        initialize_ray_cluster(
            pipeline_parallel_size=args.pipeline_parallel_size,
            tensor_parallel_size=args.tensor_parallel_size))

    # Create a server. The server schedules sequence groups and drives the
    # model workers; the measured GPU and CPU memory are passed in so it can
    # size its KV cache and CPU swap space.
    server = Server(
        model=args.model,
        model_path=args.model_path,
        use_dummy_weights=args.use_dummy_weights,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
        block_size=args.block_size,
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        max_num_batched_tokens=args.max_num_batched_tokens,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
        all_stage_devices=all_stage_devices,
        gpu_memory=get_gpu_memory(),
        cpu_memory=get_cpu_memory(),
    )

    # Create a frontend. The frontend turns raw prompts into sequence groups
    # (query / get_inputs) and prints completed responses.
    frontend = SimpleFrontend(
        model_name=args.model,
        block_size=args.block_size,
    )

    # Test the following inputs.
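    # Each entry pairs a prompt with SamplingParams overrides: 'n' is the
    # number of output sequences per prompt, and beam search with temperature
    # 0.0 ranks candidates deterministically by model likelihood.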
    test_inputs = [
        ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
        ('UC Berkeley is', {'n': 3, 'temperature': 0.8, 'top_p': 0.99}),
        ('The future of cloud computing is', {}),  # Use default parameters.
    ]
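
    # Event loop: enqueue one pending prompt per iteration, run a single
    # scheduling/execution step on the server, and print every sequence group
    # that finished during the step. Exit once all test inputs are consumed
    # and the server has no unfinished requests.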
    while True:
        if test_inputs:
            text, sampling_params_dict = test_inputs.pop(0)
            sampling_params = SamplingParams.from_dict(sampling_params_dict)
            sampling_params = frontend.add_eos_token(sampling_params)
            frontend.query(text, sampling_params)
        server.add_sequence_groups(frontend.get_inputs())
        updated_seq_groups = server.step()
        for seq_group in updated_seq_groups:
            if seq_group.is_finished():
                frontend.print_response(seq_group)
        if not (server.has_unfinished_requests() or test_inputs):
            break
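

# Example invocation (assuming add_server_arguments defines a --model flag;
# the model name below is illustrative):
#   python simple_server.py --model facebook/opt-125m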
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CacheFlow simple server.')
    parser = add_server_arguments(parser)
    args = parser.parse_args()
    main(args)