# SPDX-License-Identifier: Apache-2.0
"""Example Python client for `vllm.entrypoints.api_server`
NOTE: The API server is used only for demonstration and simple performance
benchmarks. It is not intended for production use.
For production use, we recommend `vllm serve` and the OpenAI client API.
"""

import argparse
import json
from typing import Iterable, List

import requests


def clear_line(n: int = 1) -> None:
    """Clear the previous ``n`` lines of terminal output."""
    LINE_UP = '\033[1A'  # ANSI escape: move the cursor up one line.
    LINE_CLEAR = '\x1b[2K'  # ANSI escape: erase the entire current line.
    for _ in range(n):
        print(LINE_UP, end=LINE_CLEAR, flush=True)


def post_http_request(prompt: str,
                      api_url: str,
                      n: int = 1,
                      stream: bool = False) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    # Demo payload: beam search over n candidates with greedy temperature,
    # and a short completion to keep the example fast.
    pload = {
        "prompt": prompt,
        "n": n,
        "use_beam_search": True,
        "temperature": 0.0,
        "max_tokens": 16,
        "stream": stream,
    }
    response = requests.post(api_url,
                             headers=headers,
                             json=pload,
                             stream=stream)
    return response
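
# Response format, as the two helpers below assume it: the demo /generate
# endpoint replies with JSON of the form {"text": [...]}, one string per beam
# candidate. With stream=True it instead sends a sequence of such JSON
# objects delimited by NUL bytes, each carrying the text generated so far.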


def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
    # Each streamed chunk is a NUL-delimited JSON object; decode it and yield
    # the list of beam-candidate texts generated so far.
    for chunk in response.iter_lines(chunk_size=8192,
                                     decode_unicode=False,
                                     delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


def get_response(response: requests.Response) -> List[str]:
    # Non-streaming: the full JSON body arrives at once.
    data = json.loads(response.content)
    output = data["text"]
    return output
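
# A hypothetical response body for --n 2 (the actual text depends on the
# model and server version):
#   {"text": ["San Francisco is a city of ...", "San Francisco is a major ..."]}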


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--n", type=int, default=4)
    parser.add_argument("--prompt", type=str, default="San Francisco is a")
    parser.add_argument("--stream", action="store_true")
    args = parser.parse_args()
    prompt = args.prompt
    api_url = f"http://{args.host}:{args.port}/generate"
    n = args.n
    stream = args.stream

    print(f"Prompt: {prompt!r}\n", flush=True)
    response = post_http_request(prompt, api_url, n, stream)

    if stream:
        num_printed_lines = 0
        for h in get_streaming_response(response):
            # Overwrite the previously printed candidates with the new ones.
            clear_line(num_printed_lines)
            num_printed_lines = 0
            for i, line in enumerate(h):
                num_printed_lines += 1
                print(f"Beam candidate {i}: {line!r}", flush=True)
    else:
        output = get_response(response)
        for i, line in enumerate(output):
            print(f"Beam candidate {i}: {line!r}", flush=True)