vllm/simple_server.py

39 lines
1.6 KiB
Python
Raw Normal View History

2023-03-29 14:48:56 +08:00
import argparse
2023-05-09 15:30:12 -07:00
from cacheflow.core.server import (
add_server_arguments, process_server_arguments,
init_local_server_and_frontend_with_arguments)
2023-03-29 14:48:56 +08:00
from cacheflow.sampling_params import SamplingParams
2023-05-08 23:03:35 -07:00
2023-03-29 14:48:56 +08:00
def main(args: argparse.Namespace):
    """Spin up a local CacheFlow server + frontend and drive them with a
    fixed batch of test prompts, stepping until every request finishes.

    Args:
        args: Parsed command-line namespace produced by the server's
            argument-processing helpers.
    """
    server, frontend = init_local_server_and_frontend_with_arguments(args)

    # Prompts paired with the sampling-parameter overrides to exercise.
    pending = [
        ("A robot may not injure a human being", {}),  # Use default parameters.
        ("To be or not to be,",
         {"temperature": 0.8, "top_k": 5, "presence_penalty": 0.2}),
        ("What is the meaning of life?",
         {"n": 2, "temperature": 0.8, "top_p": 0.95, "frequency_penalty": 0.1}),
        ("It is only with the heart that one can see rightly",
         {"n": 3, "use_beam_search": True, "temperature": 0.0}),
    ]

    while True:
        # Feed at most one queued prompt into the frontend per step.
        if pending:
            prompt, overrides = pending.pop(0)
            params = SamplingParams(**overrides)
            params = frontend.add_eos_token(params)
            frontend.query(prompt, params)

        # Hand any new requests to the server, then advance it one step and
        # print every sequence group that completed during this step.
        server.add_sequence_groups(frontend.get_inputs())
        for group in server.step():
            if group.is_finished():
                frontend.print_response(group)

        # Stop once the prompt queue is drained and the server is idle.
        if not (server.has_unfinished_requests() or pending):
            break
if __name__ == '__main__':
    # Build the CLI, let the server module register its flags, post-process
    # the parsed options, and run the demo loop.
    parser = argparse.ArgumentParser(description='CacheFlow simple server.')
    parser = add_server_arguments(parser)
    parsed = parser.parse_args()
    main(process_server_arguments(parsed))