Fix progress bar and allow HTTPS in benchmark_serving.py (#2552)

This commit is contained in:
Harry Mellor 2024-01-22 22:40:31 +00:00 committed by GitHub
parent 94b5edeb53
commit 63e835cbcc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -101,6 +101,7 @@ async def send_request(
output_len: int,
best_of: int,
use_beam_search: bool,
pbar: tqdm
) -> None:
request_start_time = time.perf_counter()
@ -151,6 +152,8 @@ async def send_request(
request_end_time = time.perf_counter()
request_latency = request_end_time - request_start_time
REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
pbar.update(1)
async def benchmark(
@ -163,13 +166,15 @@ async def benchmark(
request_rate: float,
) -> None:
tasks: List[asyncio.Task] = []
pbar = tqdm(total=len(input_requests))
async for request in get_request(input_requests, request_rate):
prompt, prompt_len, output_len = request
task = asyncio.create_task(
send_request(backend, model, api_url, prompt, prompt_len,
output_len, best_of, use_beam_search))
output_len, best_of, use_beam_search, pbar))
tasks.append(task)
await tqdm.gather(*tasks)
await asyncio.gather(*tasks)
pbar.close()
def main(args: argparse.Namespace):
@ -177,7 +182,7 @@ def main(args: argparse.Namespace):
random.seed(args.seed)
np.random.seed(args.seed)
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
api_url = f"{args.protocol}://{args.host}:{args.port}{args.endpoint}"
tokenizer = get_tokenizer(args.tokenizer,
trust_remote_code=args.trust_remote_code)
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
@ -212,6 +217,7 @@ if __name__ == "__main__":
type=str,
default="vllm",
choices=["vllm", "tgi"])
parser.add_argument("--protocol", type=str, default="http", choices=["http", "https"])
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--endpoint", type=str, default="/generate")