add benchmark for fix length input and output (#5857)
Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
parent
6206dcb29e
commit
333306a252
@ -77,7 +77,6 @@ def sample_sharegpt_requests(
|
|||||||
) -> List[Tuple[str, int, int]]:
|
) -> List[Tuple[str, int, int]]:
|
||||||
if fixed_output_len is not None and fixed_output_len < 4:
|
if fixed_output_len is not None and fixed_output_len < 4:
|
||||||
raise ValueError("output_len too small")
|
raise ValueError("output_len too small")
|
||||||
|
|
||||||
# Load the dataset.
|
# Load the dataset.
|
||||||
with open(dataset_path) as f:
|
with open(dataset_path) as f:
|
||||||
dataset = json.load(f)
|
dataset = json.load(f)
|
||||||
@ -185,6 +184,31 @@ def sample_sonnet_requests(
|
|||||||
return sampled_requests
|
return sampled_requests
|
||||||
|
|
||||||
|
|
||||||
|
def sample_random_requests(
|
||||||
|
input_len: int, output_len: int, num_prompts: int, range_ratio: float,
|
||||||
|
tokenizer: PreTrainedTokenizerBase) -> List[Tuple[str, int, int]]:
|
||||||
|
|
||||||
|
input_lens = np.random.randint(
|
||||||
|
int(input_len * range_ratio),
|
||||||
|
input_len + 1,
|
||||||
|
size=num_prompts,
|
||||||
|
)
|
||||||
|
output_lens = np.random.randint(
|
||||||
|
int(output_len * range_ratio),
|
||||||
|
output_len + 1,
|
||||||
|
size=num_prompts,
|
||||||
|
)
|
||||||
|
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
|
||||||
|
input_requests = []
|
||||||
|
for i in range(args.num_prompts):
|
||||||
|
prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size
|
||||||
|
for j in range(input_lens[i])])
|
||||||
|
input_requests.append(
|
||||||
|
(prompt, int(input_lens[i]), int(output_lens[i])))
|
||||||
|
|
||||||
|
return input_requests
|
||||||
|
|
||||||
|
|
||||||
async def get_request(
|
async def get_request(
|
||||||
input_requests: List[Tuple[str, int, int]],
|
input_requests: List[Tuple[str, int, int]],
|
||||||
request_rate: float,
|
request_rate: float,
|
||||||
@ -196,6 +220,7 @@ async def get_request(
|
|||||||
if request_rate == float("inf"):
|
if request_rate == float("inf"):
|
||||||
# If the request rate is infinity, then we don't need to wait.
|
# If the request rate is infinity, then we don't need to wait.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Sample the request interval from the exponential distribution.
|
# Sample the request interval from the exponential distribution.
|
||||||
interval = np.random.exponential(1.0 / request_rate)
|
interval = np.random.exponential(1.0 / request_rate)
|
||||||
# The next request will be sent after the interval.
|
# The next request will be sent after the interval.
|
||||||
@ -219,7 +244,7 @@ def calculate_metrics(
|
|||||||
# We use the tokenizer to count the number of output tokens for all
|
# We use the tokenizer to count the number of output tokens for all
|
||||||
# serving backends instead of looking at len(outputs[i].itl) since
|
# serving backends instead of looking at len(outputs[i].itl) since
|
||||||
# multiple output tokens may be bundled together
|
# multiple output tokens may be bundled together
|
||||||
# Note: this may inflate the output token count slightly
|
# Note : this may inflate the output token count slightly
|
||||||
output_len = len(
|
output_len = len(
|
||||||
tokenizer(outputs[i].generated_text,
|
tokenizer(outputs[i].generated_text,
|
||||||
add_special_tokens=False).input_ids)
|
add_special_tokens=False).input_ids)
|
||||||
@ -456,6 +481,15 @@ def main(args: argparse.Namespace):
|
|||||||
for prompt, prompt_formatted, prompt_len,
|
for prompt, prompt_formatted, prompt_len,
|
||||||
output_len in input_requests]
|
output_len in input_requests]
|
||||||
|
|
||||||
|
elif args.dataset_name == "random":
|
||||||
|
input_requests = sample_random_requests(
|
||||||
|
input_len=args.input_len,
|
||||||
|
output_len=args.output_len,
|
||||||
|
num_prompts=args.num_prompts,
|
||||||
|
range_ratio=args.range_ratio,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
||||||
|
|
||||||
@ -549,7 +583,7 @@ if __name__ == "__main__":
|
|||||||
"--dataset-name",
|
"--dataset-name",
|
||||||
type=str,
|
type=str,
|
||||||
default="sharegpt",
|
default="sharegpt",
|
||||||
choices=["sharegpt", "sonnet"],
|
choices=["sharegpt", "sonnet", "random"],
|
||||||
help="Name of the dataset to benchmark on.",
|
help="Name of the dataset to benchmark on.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--dataset-path",
|
parser.add_argument("--dataset-path",
|
||||||
@ -566,7 +600,7 @@ if __name__ == "__main__":
|
|||||||
"--tokenizer",
|
"--tokenizer",
|
||||||
type=str,
|
type=str,
|
||||||
help=
|
help=
|
||||||
"Name or path of the tokenizer, if not using the default tokenizer.",
|
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--best-of",
|
"--best-of",
|
||||||
@ -609,6 +643,27 @@ if __name__ == "__main__":
|
|||||||
help=
|
help=
|
||||||
"Number of prefix tokens per request, used only for sonnet dataset.",
|
"Number of prefix tokens per request, used only for sonnet dataset.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--random-input-len",
|
||||||
|
type=int,
|
||||||
|
default=1024,
|
||||||
|
help=
|
||||||
|
"Number of input tokens per request, used only for random sampling.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--random-output-len",
|
||||||
|
type=int,
|
||||||
|
default=128,
|
||||||
|
help=
|
||||||
|
"Number of output tokens per request, used only for random sampling.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--random-range-ratio",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="Range of sampled ratio of input/output length, "
|
||||||
|
"used only for random sampling.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--request-rate",
|
"--request-rate",
|
||||||
type=float,
|
type=float,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user