# SPDX-License-Identifier: Apache-2.0
r"""Benchmark online serving throughput.

On the server side, run one of the following commands:
    vLLM OpenAI API server
    vllm serve <your_model> \
        --swap-space 16 \
        --disable-log-requests

    (TGI backend)
    ./launch_tgi_server.sh <your_model> <max_batch_total_tokens>

On the client side, run:
    python benchmarks/benchmark_serving.py \
        --backend <backend> \
        --model <your_model> \
        --dataset-name sharegpt \
        --dataset-path <path to dataset> \
        --request-rate <request_rate> \ # By default <request_rate> is inf
        --num-prompts <num_prompts> # By default <num_prompts> is 1000

    When using the TGI backend, add
        --endpoint /generate_stream
    to the end of the command above.
"""

import argparse
import asyncio
import base64
import gc
import io
import json
import os
import random
import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple

import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
                                  RequestFuncOutput)
from datasets import load_dataset
from PIL.Image import Image
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase

try:
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
    from backend_request_func import get_tokenizer

try:
    from vllm.utils import FlexibleArgumentParser
except ImportError:
    from argparse import ArgumentParser as FlexibleArgumentParser

MILLISECONDS_TO_SECONDS_CONVERSION = 1000


@dataclass
class BenchmarkMetrics:
    completed: int
    total_input: int
    total_output: int
    request_throughput: float
    request_goodput: float
    output_throughput: float
    total_token_throughput: float
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    percentiles_ttft_ms: List[Tuple[float, float]]
    mean_tpot_ms: float
    median_tpot_ms: float
    std_tpot_ms: float
    percentiles_tpot_ms: List[Tuple[float, float]]
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    percentiles_itl_ms: List[Tuple[float, float]]
    # E2EL stands for end-to-end latency per request.
    # It is the time taken on the client side from sending
    # a request to receiving a complete response.
    mean_e2el_ms: float
    median_e2el_ms: float
    std_e2el_ms: float
    percentiles_e2el_ms: List[Tuple[float, float]]


def sample_sharegpt_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int, None]]:
    # Load the dataset.
    with open(dataset_path, encoding='utf-8') as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short
    filtered_dataset: List[Tuple[str, int, int]] = []
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = dataset[i][1]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = len(completion_token_ids
                         ) if fixed_output_len is None else fixed_output_len
        if prompt_len < 4 or (fixed_output_len is None and output_len < 4):
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len, None))

    return filtered_dataset


def sample_sonnet_requests(
    dataset_path: str,
    num_requests: int,
    input_len: int,
    output_len: int,
    prefix_len: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, str, int, int, None]]:
    assert (
        input_len > prefix_len
    ), "'args.sonnet-input-len' must be greater than 'args.sonnet-prefix-len'."

    # Load the dataset.
    with open(dataset_path, encoding='utf-8') as f:
        poem_lines = f.readlines()

    # Tokenize the poem lines.
    poem_token_ids = tokenizer(poem_lines).input_ids
    average_poem_len = sum(
        len(token_ids) for token_ids in poem_token_ids) / len(poem_token_ids)

    # Base prefix for all requests.
    base_prompt = "Pick as many lines as you can from these poem lines:\n"
    base_message = [{
        "role": "user",
        "content": base_prompt,
    }]
    base_prompt_formatted = tokenizer.apply_chat_template(
        base_message, add_generation_prompt=True, tokenize=False)
    base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids)

    assert (
        input_len > base_prompt_offset
    ), f"Please set 'args.sonnet-input-len' higher than {base_prompt_offset}."
    num_input_lines = round(
        (input_len - base_prompt_offset) / average_poem_len)

    # First approximately `prefix_len` number of tokens in the
    # prompt are fixed poem lines.
    assert (
        prefix_len > base_prompt_offset
    ), f"Please set 'args.sonnet-prefix-len' higher than {base_prompt_offset}."

    num_prefix_lines = round(
        (prefix_len - base_prompt_offset) / average_poem_len)
    prefix_lines = poem_lines[:num_prefix_lines]

    # Sample the rest of lines per request.
    sampled_requests: List[Tuple[str, int, int]] = []
    for _ in range(num_requests):
        num_lines_needed = num_input_lines - num_prefix_lines
        sampled_lines = "".join(prefix_lines +
                                random.choices(poem_lines, k=num_lines_needed))

        prompt = f"{base_prompt}{sampled_lines}"
        message = [
            {
                "role": "user",
                "content": prompt,
            },
        ]
        prompt_formatted = tokenizer.apply_chat_template(
            message, add_generation_prompt=True, tokenize=False)
        prompt_len = len(tokenizer(prompt_formatted).input_ids)
        sampled_requests.append(
            (prompt, prompt_formatted, prompt_len, output_len, None))

    return sampled_requests


def sample_vision_arena_requests(
    dataset,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
    sampled_requests: List[Tuple[str, int, int, Dict[str,
                                                     Collection[str]]]] = []
    for data in dataset:
        if len(sampled_requests) == num_requests:
            break

        prompt = data["turns"][0][0]['content']

        prompt_token_ids = tokenizer(prompt).input_ids
        if fixed_output_len is None:
            # Default max output len is set to 128
            print("--hf-output-len is not provided. Using default value 128.")
            fixed_output_len = 128

        prompt_len = len(prompt_token_ids)
        output_len = fixed_output_len

        assert isinstance(
            data["images"][0],
            Image), ("Input image format must be `PIL.Image.Image`, "
                     f"given {type(data['images'][0])}.")
        image: Image = data["images"][0]
        image = image.convert("RGB")
        image_data = io.BytesIO()
        image.save(image_data, format='JPEG')
        image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
        mm_content = {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{image_base64}"
            },
        }

        sampled_requests.append((prompt, prompt_len, output_len, mm_content))

    return sampled_requests


def sample_hf_requests(
    dataset_path: str,
    dataset_subset: Optional[str],
    dataset_split: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    random_seed: int,
    fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:

    # Special case for vision_arena dataset
    if dataset_path == 'lmarena-ai/vision-arena-bench-v0.1' \
            and dataset_subset is None:
        assert dataset_split == "train"
        dataset = load_dataset(dataset_path,
                               name=dataset_subset,
                               split=dataset_split,
                               streaming=True)
        dataset = dataset.shuffle(seed=random_seed)
        return sample_vision_arena_requests(dataset, num_requests, tokenizer,
                                            fixed_output_len)

    dataset = load_dataset(dataset_path,
                           name=dataset_subset,
                           split=dataset_split,
                           streaming=True)
    assert "conversations" in dataset.features, (
        "HF Dataset must have 'conversations' column.")
    filter_func = lambda x: len(x["conversations"]) >= 2
    filtered_dataset = dataset.shuffle(seed=random_seed).filter(filter_func)
    sampled_requests: List[Tuple[str, int, int, Dict[str,
                                                     Collection[str]]]] = []
    for data in filtered_dataset:
        if len(sampled_requests) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = data["conversations"][0]["value"]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = data["conversations"][1]["value"]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = len(completion_token_ids
                         ) if fixed_output_len is None else fixed_output_len
        if fixed_output_len is None and (prompt_len < 4 or output_len < 4):
            # Prune too short sequences.
            continue
        if fixed_output_len is None and \
            (prompt_len > 1024 or prompt_len + output_len > 2048):
            # Prune too long sequences.
            continue

        if "image" in data and isinstance(data["image"], Image):
            image: Image = data["image"]
            image = image.convert("RGB")
            image_data = io.BytesIO()
            image.save(image_data, format='JPEG')
            image_base64 = base64.b64encode(
                image_data.getvalue()).decode("utf-8")
            mm_content = {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{image_base64}"
                },
            }
        elif "image" in data and isinstance(data["image"], str):
            if (data["image"].startswith("http://") or \
                data["image"].startswith("file://")):
                image_url = data["image"]
            else:
                image_url = f"file://{data['image']}"

            mm_content = {
                "type": "image_url",
                "image_url": {
                    "url": image_url
                },
            }
        else:
            mm_content = None

        sampled_requests.append((prompt, prompt_len, output_len, mm_content))

    return sampled_requests


def sample_random_requests(
    prefix_len: int,
    input_len: int,
    output_len: int,
    num_prompts: int,
    range_ratio: float,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]:
    prefix_token_ids = np.random.randint(0,
                                         tokenizer.vocab_size,
                                         size=prefix_len).tolist()

    input_lens = np.random.randint(
        int(input_len * range_ratio),
        input_len + 1,
        size=num_prompts,
    )
    output_lens = np.random.randint(
        int(output_len * range_ratio),
        output_len + 1,
        size=num_prompts,
    )
    offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
    input_requests = []
    for i in range(num_prompts):
        prompt = tokenizer.decode(prefix_token_ids +
                                  [(offsets[i] + i + j) % tokenizer.vocab_size
                                   for j in range(input_lens[i])])

        input_requests.append((prompt, int(prefix_len + input_lens[i]),
                               int(output_lens[i]), None))

    return input_requests
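# Added note (not in the original): the randomly sampled token ids above are
# decoded to text and later re-tokenized by the server, so the effective
# prompt length can differ slightly from prefix_len + input_lens[i]; the
# prompt_len recorded in each tuple is the count of sampled token ids.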


async def get_request(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
    burstiness: float = 1.0,
) -> AsyncGenerator[Tuple[str, int, int], None]:
    """
    Asynchronously generates requests at a specified rate
    with OPTIONAL burstiness.

    Args:
        input_requests:
            A list of input requests, each represented as a tuple.
        request_rate:
            The rate at which requests are generated (requests/s).
        burstiness (optional):
            The burstiness factor of the request generation.
            Only takes effect when request_rate is not inf.
            Default value is 1, which follows a Poisson process.
            Otherwise, the request intervals follow a gamma distribution.
            A lower burstiness value (0 < burstiness < 1) results
            in more bursty requests, while a higher burstiness value
            (burstiness > 1) results in a more uniform arrival of requests.
    """
    input_requests = iter(input_requests)

    # Calculate scale parameter theta to maintain the desired request_rate.
    assert burstiness > 0, (
        f"A positive burstiness factor is expected, but given {burstiness}.")
    theta = 1.0 / (request_rate * burstiness)

    for request in input_requests:
        yield request

        if request_rate == float("inf"):
            # If the request rate is infinity, then we don't need to wait.
            continue

        # Sample the request interval from the gamma distribution.
        # If burstiness is 1, it follows an exponential distribution.
        interval = np.random.gamma(shape=burstiness, scale=theta)
        # The next request will be sent after the interval.
        await asyncio.sleep(interval)
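# Added note: np.random.gamma(shape=burstiness, scale=theta) with
# theta = 1 / (request_rate * burstiness) gives a mean interval of
# shape * scale = 1 / request_rate, so burstiness reshapes the arrival
# pattern without changing the average request rate.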


def calculate_metrics(
    input_requests: List[Tuple[str, int, int]],
    outputs: List[RequestFuncOutput],
    dur_s: float,
    tokenizer: PreTrainedTokenizerBase,
    selected_percentile_metrics: List[str],
    selected_percentiles: List[float],
    goodput_config_dict: Dict[str, float],
) -> Tuple[BenchmarkMetrics, List[int]]:
    actual_output_lens: List[int] = []
    total_input = 0
    completed = 0
    good_completed = 0
    itls: List[float] = []
    tpots: List[float] = []
    all_tpots: List[float] = []
    ttfts: List[float] = []
    e2els: List[float] = []
    for i in range(len(outputs)):
        if outputs[i].success:
            output_len = outputs[i].output_tokens

            if output_len is None:
                # We use the tokenizer to count the number of output tokens
                # for some serving backends instead of looking at
                # len(outputs[i].itl) since multiple output tokens may be
                # bundled together.
                # Note: this may inflate the output token count slightly.
                output_len = len(
                    tokenizer(outputs[i].generated_text,
                              add_special_tokens=False).input_ids)
            actual_output_lens.append(output_len)
            total_input += input_requests[i][1]
            tpot = 0
            if output_len > 1:
                latency_minus_ttft = outputs[i].latency - outputs[i].ttft
                tpot = latency_minus_ttft / (output_len - 1)
                tpots.append(tpot)
            # Note: if output_len <= 1, we regard tpot as 0 for goodput
            all_tpots.append(tpot)
            itls += outputs[i].itl
            ttfts.append(outputs[i].ttft)
            e2els.append(outputs[i].latency)
            completed += 1
        else:
            actual_output_lens.append(0)

    if goodput_config_dict:
        valid_metrics = []
        slo_values = []

        if "ttft" in goodput_config_dict:
            valid_metrics.append(ttfts)
            slo_values.append(goodput_config_dict["ttft"] /
                              MILLISECONDS_TO_SECONDS_CONVERSION)
        if "tpot" in goodput_config_dict:
            valid_metrics.append(all_tpots)
            slo_values.append(goodput_config_dict["tpot"] /
                              MILLISECONDS_TO_SECONDS_CONVERSION)
        if "e2el" in goodput_config_dict:
            valid_metrics.append(e2els)
            slo_values.append(goodput_config_dict["e2el"] /
                              MILLISECONDS_TO_SECONDS_CONVERSION)

        for req_metric in zip(*valid_metrics):
            is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
            if is_good_req:
                good_completed += 1

    if completed == 0:
        warnings.warn(
            "All requests failed. This is likely due to a misconfiguration "
            "on the benchmark arguments.",
            stacklevel=2)
    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=sum(actual_output_lens),
        request_throughput=completed / dur_s,
        request_goodput=good_completed / dur_s,
        output_throughput=sum(actual_output_lens) / dur_s,
        total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
        mean_ttft_ms=np.mean(ttfts or 0) *
        1000,  # ttfts is empty if streaming is not supported by backend
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
                             for p in selected_percentiles],
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
                             for p in selected_percentiles],
        mean_itl_ms=np.mean(itls or 0) * 1000,
        std_itl_ms=np.std(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
        percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
                            for p in selected_percentiles],
        mean_e2el_ms=np.mean(e2els or 0) * 1000,
        std_e2el_ms=np.std(e2els or 0) * 1000,
        median_e2el_ms=np.median(e2els or 0) * 1000,
        percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
                             for p in selected_percentiles],
    )

    return metrics, actual_output_lens
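# Added note: TPOT above is computed as (latency - ttft) / (output_len - 1),
# i.e. it excludes the first token, whereas ITL values are taken directly from
# the gaps between streamed chunks recorded by the request function.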


async def benchmark(
    backend: str,
    api_url: str,
    base_url: str,
    model_id: str,
    model_name: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: List[Tuple[str, int, int]],
    logprobs: Optional[int],
    best_of: int,
    request_rate: float,
    burstiness: float,
    disable_tqdm: bool,
    profile: bool,
    selected_percentile_metrics: List[str],
    selected_percentiles: List[str],
    ignore_eos: bool,
    goodput_config_dict: Dict[str, float],
    max_concurrency: Optional[int],
    lora_modules: Optional[List[str]],
):
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    print("Starting initial single prompt test run...")
    test_prompt, test_prompt_len, test_output_len, test_mm_content = (
        input_requests[0])
    if backend != "openai-chat" and test_mm_content is not None:
        # Multi-modal benchmarking is only available on the OpenAI Chat backend.
        raise ValueError(
            "Multi-modal content is only supported on 'openai-chat' backend.")
    test_input = RequestFuncInput(
        model=model_id,
        model_name=model_name,
        prompt=test_prompt,
        api_url=api_url,
        prompt_len=test_prompt_len,
        output_len=test_output_len,
        logprobs=logprobs,
        best_of=best_of,
        multi_modal_content=test_mm_content,
        ignore_eos=ignore_eos,
    )

    test_output = await request_func(request_func_input=test_input)
    if not test_output.success:
        raise ValueError(
            "Initial test run failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {test_output.error}")
    else:
        print("Initial test run completed. Starting main benchmark run...")

    if lora_modules:
        # For each input request, choose a LoRA module at random.
        lora_modules = iter(
            [random.choice(lora_modules) for _ in range(len(input_requests))])

    if profile:
        print("Starting profiler...")
        profile_input = RequestFuncInput(model=model_id,
                                         model_name=model_name,
                                         prompt=test_prompt,
                                         api_url=base_url + "/start_profile",
                                         prompt_len=test_prompt_len,
                                         output_len=test_output_len,
                                         logprobs=logprobs,
                                         best_of=best_of,
                                         multi_modal_content=test_mm_content,
                                         ignore_eos=ignore_eos)
        profile_output = await request_func(request_func_input=profile_input)
        if profile_output.success:
            print("Profiler started")

    if burstiness == 1.0:
        distribution = "Poisson process"
    else:
        distribution = "Gamma distribution"

    print(f"Traffic request rate: {request_rate}")
    print(f"Burstiness factor: {burstiness} ({distribution})")
    print(f"Maximum request concurrency: {max_concurrency}")

    pbar = None if disable_tqdm else tqdm(total=len(input_requests))

    # This can be used once the minimum Python version is 3.10 or higher,
    # and it will simplify the code in limited_request_func.
    # semaphore = (asyncio.Semaphore(max_concurrency)
    #              if max_concurrency else contextlib.nullcontext())
    semaphore = (asyncio.Semaphore(max_concurrency)
                 if max_concurrency else None)

    async def limited_request_func(request_func_input, pbar):
        if semaphore is None:
            return await request_func(request_func_input=request_func_input,
                                      pbar=pbar)
        async with semaphore:
            return await request_func(request_func_input=request_func_input,
                                      pbar=pbar)

    benchmark_start_time = time.perf_counter()
    tasks: List[asyncio.Task] = []
    async for request in get_request(input_requests, request_rate, burstiness):
        prompt, prompt_len, output_len, mm_content = request
        req_model_id, req_model_name = model_id, model_name
        if lora_modules:
            req_lora_module = next(lora_modules)
            req_model_id, req_model_name = req_lora_module, req_lora_module

        request_func_input = RequestFuncInput(model=req_model_id,
                                              model_name=req_model_name,
                                              prompt=prompt,
                                              api_url=api_url,
                                              prompt_len=prompt_len,
                                              output_len=output_len,
                                              logprobs=logprobs,
                                              best_of=best_of,
                                              multi_modal_content=mm_content,
                                              ignore_eos=ignore_eos)
        tasks.append(
            asyncio.create_task(
                limited_request_func(request_func_input=request_func_input,
                                     pbar=pbar)))
    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

    if profile:
        print("Stopping profiler...")
        profile_input = RequestFuncInput(
            model=model_id,
            prompt=test_prompt,
            api_url=base_url + "/stop_profile",
            prompt_len=test_prompt_len,
            output_len=test_output_len,
            logprobs=logprobs,
            best_of=best_of,
        )
        profile_output = await request_func(request_func_input=profile_input)
        if profile_output.success:
            print("Profiler stopped")

    if pbar is not None:
        pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time

    metrics, actual_output_lens = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        selected_percentile_metrics=selected_percentile_metrics,
        selected_percentiles=selected_percentiles,
        goodput_config_dict=goodput_config_dict,
    )

    print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
                                    benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:",
                                 metrics.total_output))
    print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
                                    metrics.request_throughput))
    if goodput_config_dict:
        print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
                                        metrics.request_goodput))
    print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
                                    metrics.output_throughput))
    print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
                                    metrics.total_token_throughput))

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "request_throughput": metrics.request_throughput,
        "request_goodput":
        metrics.request_goodput if goodput_config_dict else None,
        "output_throughput": metrics.output_throughput,
        "total_token_throughput": metrics.total_token_throughput,
        "input_lens": [output.prompt_len for output in outputs],
        "output_lens": actual_output_lens,
        "ttfts": [output.ttft for output in outputs],
        "itls": [output.itl for output in outputs],
        "generated_texts": [output.generated_text for output in outputs],
        "errors": [output.error for output in outputs],
    }

    def process_one_metric(
        # E.g., "ttft"
        metric_attribute_name: str,
        # E.g., "TTFT"
        metric_name: str,
        # E.g., "Time to First Token"
        metric_header: str,
    ):
        # This function prints and adds statistics of the specified
        # metric.
        if metric_attribute_name not in selected_percentile_metrics:
            return
        print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
        print("{:<40} {:<10.2f}".format(
            f"Mean {metric_name} (ms):",
            getattr(metrics, f"mean_{metric_attribute_name}_ms")))
        print("{:<40} {:<10.2f}".format(
            f"Median {metric_name} (ms):",
            getattr(metrics, f"median_{metric_attribute_name}_ms")))
        result[f"mean_{metric_attribute_name}_ms"] = getattr(
            metrics, f"mean_{metric_attribute_name}_ms")
        result[f"median_{metric_attribute_name}_ms"] = getattr(
            metrics, f"median_{metric_attribute_name}_ms")
        result[f"std_{metric_attribute_name}_ms"] = getattr(
            metrics, f"std_{metric_attribute_name}_ms")
        for p, value in getattr(metrics,
                                f"percentiles_{metric_attribute_name}_ms"):
            p_word = str(int(p)) if int(p) == p else str(p)
            print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
                                            value))
            result[f"p{p_word}_{metric_attribute_name}_ms"] = value

    process_one_metric("ttft", "TTFT", "Time to First Token")
    process_one_metric("tpot", "TPOT",
                       "Time per Output Token (excl. 1st token)")
    process_one_metric("itl", "ITL", "Inter-token Latency")
    process_one_metric("e2el", "E2EL", "End-to-end Latency")

    print("=" * 50)

    return result


def check_goodput_args(args):
    # Check and parse goodput arguments
    goodput_config_dict = {}
    VALID_NAMES = ["ttft", "tpot", "e2el"]
    if args.goodput:
        goodput_config_dict = parse_goodput(args.goodput)
        for slo_name, slo_val in goodput_config_dict.items():
            if slo_name not in VALID_NAMES:
                raise ValueError(
                    f"Invalid metric name found, {slo_name}: {slo_val}. "
                    "The service level objective name should be one of "
                    f"{str(VALID_NAMES)}. ")
            if slo_val < 0:
                raise ValueError(
                    f"Invalid value found, {slo_name}: {slo_val}. "
                    "The service level objective value should be "
                    "non-negative.")
    return goodput_config_dict


def parse_goodput(slo_pairs):
    goodput_config_dict = {}
    try:
        for slo_pair in slo_pairs:
            slo_name, slo_val = slo_pair.split(":")
            goodput_config_dict[slo_name] = float(slo_val)
    except ValueError as err:
        raise argparse.ArgumentTypeError(
            "Invalid format found for service level objectives. "
            "Specify service level objectives for goodput as \"KEY:VALUE\" "
            "pairs, where the key is a metric name, and the value is a "
            "number in milliseconds.") from err
    return goodput_config_dict
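# Illustrative example (added): `--goodput ttft:200 tpot:50` counts a request
# as "good" only if its TTFT is at most 200 ms and its TPOT is at most 50 ms.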


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    backend = args.backend
    model_id = args.model
    model_name = args.served_model_name
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
    tokenizer_mode = args.tokenizer_mode

    if args.base_url is not None:
        api_url = f"{args.base_url}{args.endpoint}"
        base_url = f"{args.base_url}"
    else:
        api_url = f"http://{args.host}:{args.port}{args.endpoint}"
        base_url = f"http://{args.host}:{args.port}"

    tokenizer = get_tokenizer(tokenizer_id,
                              tokenizer_mode=tokenizer_mode,
                              trust_remote_code=args.trust_remote_code)

    if args.dataset is not None:
        warnings.warn(
            "The '--dataset' argument will be deprecated in the next "
            "release. Please use '--dataset-name' and "
            "'--dataset-path' in future runs.",
            stacklevel=2)
        input_requests = sample_sharegpt_requests(
            dataset_path=args.dataset,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            fixed_output_len=args.sharegpt_output_len,
        )

    elif args.dataset_name == "sharegpt":
        input_requests = sample_sharegpt_requests(
            dataset_path=args.dataset_path,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            fixed_output_len=args.sharegpt_output_len,
        )

    elif args.dataset_name == "sonnet":
        # Do not format the prompt; pass it to the message directly.
        if args.backend == "openai-chat":
            input_requests = sample_sonnet_requests(
                dataset_path=args.dataset_path,
                num_requests=args.num_prompts,
                input_len=args.sonnet_input_len,
                output_len=args.sonnet_output_len,
                prefix_len=args.sonnet_prefix_len,
                tokenizer=tokenizer,
            )
            input_requests = [(prompt, prompt_len, output_len, None)
                              for prompt, prompt_formatted, prompt_len,
                              output_len, _ in input_requests]
        else:
            assert (
                tokenizer.chat_template or tokenizer.default_chat_template
            ), "Tokenizer/model must have chat template for sonnet dataset."
            input_requests = sample_sonnet_requests(
                dataset_path=args.dataset_path,
                num_requests=args.num_prompts,
                input_len=args.sonnet_input_len,
                output_len=args.sonnet_output_len,
                prefix_len=args.sonnet_prefix_len,
                tokenizer=tokenizer,
            )
            input_requests = [(prompt_formatted, prompt_len, output_len, None)
                              for prompt, prompt_formatted, prompt_len,
                              output_len, _ in input_requests]

    elif args.dataset_name == "hf":
        input_requests = sample_hf_requests(
            dataset_path=args.dataset_path,
            dataset_subset=args.hf_subset,
            dataset_split=args.hf_split,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            random_seed=args.seed,
            fixed_output_len=args.hf_output_len,
        )

    elif args.dataset_name == "random":
        input_requests = sample_random_requests(
            prefix_len=args.random_prefix_len,
            input_len=args.random_input_len,
            output_len=args.random_output_len,
            num_prompts=args.num_prompts,
            range_ratio=args.random_range_ratio,
            tokenizer=tokenizer,
        )

    else:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")

    goodput_config_dict = check_goodput_args(args)

    # Avoid GC processing "static" data - reduce pause times.
    gc.collect()
    gc.freeze()

    benchmark_result = asyncio.run(
        benchmark(
            backend=backend,
            api_url=api_url,
            base_url=base_url,
            model_id=model_id,
            model_name=model_name,
            tokenizer=tokenizer,
            input_requests=input_requests,
            logprobs=args.logprobs,
            best_of=args.best_of,
            request_rate=args.request_rate,
            burstiness=args.burstiness,
            disable_tqdm=args.disable_tqdm,
            profile=args.profile,
            selected_percentile_metrics=args.percentile_metrics.split(","),
            selected_percentiles=[
                float(p) for p in args.metric_percentiles.split(",")
            ],
            ignore_eos=args.ignore_eos,
            goodput_config_dict=goodput_config_dict,
            max_concurrency=args.max_concurrency,
            lora_modules=args.lora_modules,
        ))

    # Save config and results to json
    if args.save_result:
        result_json: Dict[str, Any] = {}

        # Setup
        current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
        result_json["date"] = current_dt
        result_json["backend"] = backend
        result_json["model_id"] = model_id
        result_json["tokenizer_id"] = tokenizer_id
        result_json["best_of"] = args.best_of
        result_json["num_prompts"] = args.num_prompts

        # Metadata
        if args.metadata:
            for item in args.metadata:
                if "=" in item:
                    kvstring = item.split("=")
                    result_json[kvstring[0].strip()] = kvstring[1].strip()
                else:
                    raise ValueError(
                        "Invalid metadata format. Please use KEY=VALUE format."
                    )

        # Traffic
        result_json["request_rate"] = (args.request_rate if args.request_rate
                                       < float("inf") else "inf")
        result_json["burstiness"] = args.burstiness
        result_json["max_concurrency"] = args.max_concurrency

        # Merge with benchmark result
        result_json = {**result_json, **benchmark_result}

        # Save to file
        base_model_id = model_id.split("/")[-1]
        max_concurrency_str = (f"-concurrency{args.max_concurrency}"
                               if args.max_concurrency is not None else "")
        file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  #noqa
        if args.result_filename:
            file_name = args.result_filename
        if args.result_dir:
            file_name = os.path.join(args.result_dir, file_name)
        with open(file_name, "w", encoding='utf-8') as outfile:
            json.dump(result_json, outfile)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the online serving throughput.")
    parser.add_argument(
        "--backend",
        type=str,
        default="vllm",
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--endpoint",
        type=str,
        default="/v1/completions",
        help="API endpoint.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Path to the ShareGPT dataset, will be deprecated in the "
        "next release.",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="sharegpt",
        choices=["sharegpt", "sonnet", "random", "hf"],
        help="Name of the dataset to benchmark on.",
    )
    parser.add_argument("--dataset-path",
                        type=str,
                        default=None,
                        help="Path to the sharegpt/sonnet dataset. "
                        "Or the huggingface dataset ID if using HF dataset.")
    parser.add_argument(
        "--max-concurrency",
        type=int,
        default=None,
        help="Maximum number of concurrent requests. This can be used "
        "to help simulate an environment where a higher level component "
        "is enforcing a maximum number of concurrent requests. While the "
        "--request-rate argument controls the rate at which requests are "
        "initiated, this argument will control how many are actually allowed "
        "to execute at a time. This means that when used in combination, the "
        "actual request rate may be lower than specified with --request-rate, "
        "if the server is not processing requests fast enough to keep up.")

    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Name of the model.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help=
        "Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
    )
    parser.add_argument(
        "--best-of",
        type=int,
        default=1,
        help="Generates `best_of` sequences per prompt and "
        "returns the best one.",
    )
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1000,
        help="Number of prompts to process.",
    )
    parser.add_argument(
        "--logprobs",
        type=int,
        default=None,
        help=("Number of logprobs-per-token to compute & return as part of "
              "the request. If unspecified, then either (1) if beam search "
              "is disabled, no logprobs are computed & a single dummy "
              "logprob is returned for each token; or (2) if beam search "
              "is enabled, 1 logprob per token is computed"),
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, "
        "then all the requests are sent at time 0. "
        "Otherwise, we use a Poisson process or gamma distribution "
        "to synthesize the request arrival times.",
    )
    parser.add_argument(
        "--burstiness",
        type=float,
        default=1.0,
        help="Burstiness factor of the request generation. "
        "Only takes effect when request_rate is not inf. "
        "Default value is 1, which follows a Poisson process. "
        "Otherwise, the request intervals follow a gamma distribution. "
        "A lower burstiness value (0 < burstiness < 1) results in more "
        "bursty requests. A higher burstiness value (burstiness > 1) "
        "results in a more uniform arrival of requests.",
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Trust remote code from huggingface",
    )
    parser.add_argument(
        "--disable-tqdm",
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        help="Use Torch Profiler. The endpoint must be launched with "
        "VLLM_TORCH_PROFILER_DIR to enable profiler.",
    )
    parser.add_argument(
        "--save-result",
        action="store_true",
        help="Specify to save benchmark results to a json file",
    )
    parser.add_argument(
        "--metadata",
        metavar="KEY=VALUE",
        nargs="*",
        help="Key-value pairs (e.g., --metadata version=0.3.3 tp=1) "
        "for metadata of this run to be saved in the result JSON file "
        "for record keeping purposes.",
    )
    parser.add_argument(
        "--result-dir",
        type=str,
        default=None,
        help="Specify directory to save benchmark json results. "
        "If not specified, results are saved in the current directory.",
    )
    parser.add_argument(
        "--result-filename",
        type=str,
        default=None,
        help="Specify the filename to save benchmark json results. "
        "If not specified, results will be saved in "
        "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
        " format.",
    )
    parser.add_argument(
        "--ignore-eos",
        action="store_true",
        help="Set ignore_eos flag when sending the benchmark request. "
        "Warning: ignore_eos is not supported in deepspeed_mii and tgi.")
    parser.add_argument(
        "--percentile-metrics",
        type=str,
        default="ttft,tpot,itl",
        help="Comma-separated list of selected metrics to report percentiles "
        "for. Allowed metric names are \"ttft\", \"tpot\", \"itl\", "
        "\"e2el\". Default value is \"ttft,tpot,itl\".")
    parser.add_argument(
        "--metric-percentiles",
        type=str,
        default="99",
        help="Comma-separated list of percentiles for selected metrics. "
        "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
        "Default value is \"99\". "
        "Use \"--percentile-metrics\" to select metrics.",
    )
    parser.add_argument(
        "--goodput",
        nargs="+",
        required=False,
        help="Specify service level objectives for goodput as \"KEY:VALUE\" "
        "pairs, where the key is a metric name, and the value is in "
        "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, "
        "separated by spaces. Allowed request level metric names are "
        "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
        "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
        "and the blog: https://hao-ai-lab.github.io/blogs/distserve")

    # group for dataset specific arguments
    sonnet_group = parser.add_argument_group("sonnet dataset options")
    sonnet_group.add_argument(
        "--sonnet-input-len",
        type=int,
        default=550,
        help=
        "Number of input tokens per request, used only for sonnet dataset.",
    )
    sonnet_group.add_argument(
        "--sonnet-output-len",
        type=int,
        default=150,
        help=
        "Number of output tokens per request, used only for sonnet dataset.",
    )
    sonnet_group.add_argument(
        "--sonnet-prefix-len",
        type=int,
        default=200,
        help=
        "Number of prefix tokens per request, used only for sonnet dataset.",
    )

    sharegpt_group = parser.add_argument_group("sharegpt dataset options")
    sharegpt_group.add_argument(
        "--sharegpt-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output length "
        "from the ShareGPT dataset.")

    random_group = parser.add_argument_group("random dataset options")
    random_group.add_argument(
        "--random-input-len",
        type=int,
        default=1024,
        help=
        "Number of input tokens per request, used only for random sampling.",
    )
    random_group.add_argument(
        "--random-output-len",
        type=int,
        default=128,
        help=
        "Number of output tokens per request, used only for random sampling.",
    )
    random_group.add_argument(
        "--random-range-ratio",
        type=float,
        default=1.0,
        help="Range of sampled ratio of input/output length, "
        "used only for random sampling.",
    )
    random_group.add_argument(
        "--random-prefix-len",
        type=int,
        default=0,
        help="Number of fixed prefix tokens before the random context. "
        "The total input length of a random request is sampled from "
        "[random-prefix-len + random-input-len * random-range-ratio, "
        "random-prefix-len + random-input-len].")

    hf_group = parser.add_argument_group("hf dataset options")
    hf_group.add_argument("--hf-subset",
                          type=str,
                          default=None,
                          help="Subset of the HF dataset.")
    hf_group.add_argument("--hf-split",
                          type=str,
                          default=None,
                          help="Split of the HF dataset.")
    hf_group.add_argument(
        "--hf-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output lengths "
        "from the sampled HF dataset.",
    )

    parser.add_argument(
        '--tokenizer-mode',
        type=str,
        default="auto",
        choices=['auto', 'slow', 'mistral'],
        help='The tokenizer mode.\n\n* "auto" will use the '
        'fast tokenizer if available.\n* "slow" will '
        'always use the slow tokenizer. \n* '
        '"mistral" will always use the `mistral_common` tokenizer.')

    parser.add_argument("--served-model-name",
                        type=str,
                        default=None,
                        help="The model name used in the API. "
                        "If not specified, the model name will be the "
                        "same as the ``--model`` argument. ")

    parser.add_argument("--lora-modules",
                        nargs='+',
                        default=None,
                        help="A subset of LoRA module names passed in when "
                        "launching the server. For each request, the "
                        "script chooses a LoRA module at random.")

    args = parser.parse_args()
    main(args)