Fix progress bar and allow HTTPS in benchmark_serving.py (#2552)
This commit is contained in:
parent 94b5edeb53
commit 63e835cbcc
@ -101,6 +101,7 @@ async def send_request(
|
|||||||
output_len: int,
|
output_len: int,
|
||||||
best_of: int,
|
best_of: int,
|
||||||
use_beam_search: bool,
|
use_beam_search: bool,
|
||||||
|
pbar: tqdm
|
||||||
) -> None:
|
) -> None:
|
||||||
request_start_time = time.perf_counter()
|
request_start_time = time.perf_counter()
|
||||||
|
|
||||||
@ -151,6 +152,8 @@ async def send_request(
|
|||||||
request_end_time = time.perf_counter()
|
request_end_time = time.perf_counter()
|
||||||
request_latency = request_end_time - request_start_time
|
request_latency = request_end_time - request_start_time
|
||||||
REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
|
REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def benchmark(
|
async def benchmark(
|
||||||
@ -163,13 +166,15 @@ async def benchmark(
|
|||||||
request_rate: float,
|
request_rate: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
tasks: List[asyncio.Task] = []
|
tasks: List[asyncio.Task] = []
|
||||||
|
pbar = tqdm(total=len(input_requests))
|
||||||
async for request in get_request(input_requests, request_rate):
|
async for request in get_request(input_requests, request_rate):
|
||||||
prompt, prompt_len, output_len = request
|
prompt, prompt_len, output_len = request
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
send_request(backend, model, api_url, prompt, prompt_len,
|
send_request(backend, model, api_url, prompt, prompt_len,
|
||||||
output_len, best_of, use_beam_search))
|
output_len, best_of, use_beam_search, pbar))
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
await tqdm.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
pbar.close()
|
||||||
|
|
||||||
|
|
||||||
def main(args: argparse.Namespace):
|
def main(args: argparse.Namespace):
|
||||||
@ -177,7 +182,7 @@ def main(args: argparse.Namespace):
|
|||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
|
|
||||||
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
|
api_url = f"{args.protocol}://{args.host}:{args.port}{args.endpoint}"
|
||||||
tokenizer = get_tokenizer(args.tokenizer,
|
tokenizer = get_tokenizer(args.tokenizer,
|
||||||
trust_remote_code=args.trust_remote_code)
|
trust_remote_code=args.trust_remote_code)
|
||||||
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
||||||
@ -212,6 +217,7 @@ if __name__ == "__main__":
|
|||||||
type=str,
|
type=str,
|
||||||
default="vllm",
|
default="vllm",
|
||||||
choices=["vllm", "tgi"])
|
choices=["vllm", "tgi"])
|
||||||
|
parser.add_argument("--protocol", type=str, default="http", choices=["http", "https"])
|
||||||
parser.add_argument("--host", type=str, default="localhost")
|
parser.add_argument("--host", type=str, default="localhost")
|
||||||
parser.add_argument("--port", type=int, default=8000)
|
parser.add_argument("--port", type=int, default=8000)
|
||||||
parser.add_argument("--endpoint", type=str, default="/generate")
|
parser.add_argument("--endpoint", type=str, default="/generate")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user