[Misc] Update ShareGPT Dataset Sampling in Serving Benchmark (#4279)
commit 7923dcad12
parent 3cd9b5bb2d
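In brief: the old ShareGPT path sampled roughly 1.2x num_requests conversations up front, bulk-tokenized them, filtered, and then drew exactly num_requests entries with random.sample, which raises a ValueError if too few entries survive filtering. The new path shuffles the whole dataset, tokenizes lazily, and keeps collecting filtered entries until num_requests are found. It also adds an optional fixed_output_len parameter (wired to a new --sharegpt-output-len flag) that overrides the per-request output length otherwise taken from each dataset completion.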
@@ -27,7 +27,7 @@ import time
 import warnings
 from dataclasses import dataclass
 from datetime import datetime
-from typing import AsyncGenerator, List, Tuple
+from typing import AsyncGenerator, List, Optional, Tuple
 
 import numpy as np
 from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
@@ -58,7 +58,11 @@ def sample_sharegpt_requests(
     dataset_path: str,
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
+    fixed_output_len: Optional[int] = None,
 ) -> List[Tuple[str, int, int]]:
+    if fixed_output_len is not None and fixed_output_len < 4:
+        raise ValueError("output_len too small")
+
     # Load the dataset.
     with open(dataset_path) as f:
         dataset = json.load(f)
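Note the fail-fast guard added above: a fixed_output_len below 4 would be rejected by the output_len < 4 prune for every request, so the function raises immediately instead of returning an empty sample.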
@@ -68,38 +72,32 @@ def sample_sharegpt_requests(
     dataset = [(data["conversations"][0]["value"],
                 data["conversations"][1]["value"]) for data in dataset]
 
-    # some of these will be filtered out, so sample more than we need
-    sampled_indices = random.sample(range(len(dataset)),
-                                    int(num_requests * 1.2))
-    dataset = [dataset[i] for i in sampled_indices]
+    # Shuffle the dataset.
+    random.shuffle(dataset)
 
-    # Tokenize the prompts and completions.
-    prompts = [prompt for prompt, _ in dataset]
-    prompt_token_ids = tokenizer(prompts).input_ids
-    completions = [completion for _, completion in dataset]
-    completion_token_ids = tokenizer(completions).input_ids
-    tokenized_dataset = []
-    for i in range(len(dataset)):
-        output_len = len(completion_token_ids[i])
-        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
-
-    # Filter out too long sequences.
+    # Filter out sequences that are too long or too short
    filtered_dataset: List[Tuple[str, int, int]] = []
-    for prompt, prompt_token_ids, output_len in tokenized_dataset:
+    for i in range(len(dataset)):
+        if len(filtered_dataset) == num_requests:
+            break
+
+        # Tokenize the prompts and completions.
+        prompt = dataset[i][0]
+        prompt_token_ids = tokenizer(prompt).input_ids
+        completion = dataset[i][1]
+        completion_token_ids = tokenizer(completion).input_ids
         prompt_len = len(prompt_token_ids)
+        output_len = len(completion_token_ids
+                         ) if fixed_output_len is None else fixed_output_len
         if prompt_len < 4 or output_len < 4:
             # Prune too short sequences.
-            # This is because TGI causes errors when the input or output length
-            # is too short.
             continue
         if prompt_len > 1024 or prompt_len + output_len > 2048:
             # Prune too long sequences.
             continue
         filtered_dataset.append((prompt, prompt_len, output_len))
 
-    # Sample the requests.
-    sampled_requests = random.sample(filtered_dataset, num_requests)
-    return sampled_requests
+    return filtered_dataset
 
 
 def sample_sonnet_requests(
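For reference, here is a minimal standalone sketch of the new sampling loop. It substitutes a whitespace token count for the HuggingFace tokenizer, and the function name and toy data are illustrative only, not part of the commit:

import random
from typing import List, Optional, Tuple


def sample_sharegpt_sketch(
    pairs: List[Tuple[str, str]],  # (prompt, completion) pairs
    num_requests: int,
    fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]:
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Shuffle, then walk the data until num_requests entries survive the
    # filters; this replaces the old "over-sample 1.2x, filter, then
    # random.sample(num_requests)" scheme.
    pairs = list(pairs)
    random.shuffle(pairs)
    filtered: List[Tuple[str, int, int]] = []
    for prompt, completion in pairs:
        if len(filtered) == num_requests:
            break
        prompt_len = len(prompt.split())  # stand-in for tokenizer(...).input_ids
        output_len = (len(completion.split())
                      if fixed_output_len is None else fixed_output_len)
        if prompt_len < 4 or output_len < 4:
            continue  # prune too-short sequences
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            continue  # prune too-long sequences
        filtered.append((prompt, prompt_len, output_len))
    return filtered


if __name__ == "__main__":
    toy = [("tell me a story about a robot", "once upon a time there was one"),
           ("hi", "hello")] * 4  # the short pair is always pruned
    for prompt, plen, olen in sample_sharegpt_sketch(toy, 3, fixed_output_len=128):
        print(plen, olen)  # olen is always 128 when fixed_output_len is set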
@@ -361,6 +359,7 @@ def main(args: argparse.Namespace):
             dataset_path=args.dataset,
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
+            fixed_output_len=args.sharegpt_output_len,
         )
 
     elif args.dataset_name == "sharegpt":
@@ -368,6 +367,7 @@ def main(args: argparse.Namespace):
             dataset_path=args.dataset_path,
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
+            fixed_output_len=args.sharegpt_output_len,
         )
 
     elif args.dataset_name == "sonnet":
@@ -524,6 +524,12 @@ if __name__ == "__main__":
         default=1000,
         help="Number of prompts to process.",
     )
+    parser.add_argument(
+        "--sharegpt-output-len",
+        type=int,
+        default=None,
+        help="Output length for each request. Overrides the output length "
+        "from the ShareGPT dataset.")
     parser.add_argument(
         "--sonnet-input-len",
         type=int,
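With the new flag, a ShareGPT run can pin every request's output length, for example (the script name is assumed here, since the patched file is not shown on this page, but the other flags all appear in the diff): python benchmark_serving.py --dataset-name sharegpt --dataset-path &lt;path/to/sharegpt.json&gt; --num-prompts 1000 --sharegpt-output-len 128. Leaving --sharegpt-output-len unset (default None) keeps the previous behavior of using each completion's own token length.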