diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index 6054df43..2c2d69da 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -27,7 +27,7 @@ import time
 import warnings
 from dataclasses import dataclass
 from datetime import datetime
-from typing import AsyncGenerator, List, Tuple
+from typing import AsyncGenerator, List, Optional, Tuple
 
 import numpy as np
 from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
@@ -58,7 +58,11 @@ def sample_sharegpt_requests(
     dataset_path: str,
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
+    fixed_output_len: Optional[int] = None,
 ) -> List[Tuple[str, int, int]]:
+    if fixed_output_len is not None and fixed_output_len < 4:
+        raise ValueError("output_len too small")
+
     # Load the dataset.
     with open(dataset_path) as f:
         dataset = json.load(f)
@@ -68,38 +72,32 @@ def sample_sharegpt_requests(
     dataset = [(data["conversations"][0]["value"],
                 data["conversations"][1]["value"]) for data in dataset]
 
-    # some of these will be filtered out, so sample more than we need
-    sampled_indices = random.sample(range(len(dataset)),
-                                    int(num_requests * 1.2))
-    dataset = [dataset[i] for i in sampled_indices]
+    # Shuffle the dataset.
+    random.shuffle(dataset)
 
-    # Tokenize the prompts and completions.
-    prompts = [prompt for prompt, _ in dataset]
-    prompt_token_ids = tokenizer(prompts).input_ids
-    completions = [completion for _, completion in dataset]
-    completion_token_ids = tokenizer(completions).input_ids
-    tokenized_dataset = []
-    for i in range(len(dataset)):
-        output_len = len(completion_token_ids[i])
-        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
-
-    # Filter out too long sequences.
+    # Filter out sequences that are too long or too short
     filtered_dataset: List[Tuple[str, int, int]] = []
-    for prompt, prompt_token_ids, output_len in tokenized_dataset:
+    for i in range(len(dataset)):
+        if len(filtered_dataset) == num_requests:
+            break
+
+        # Tokenize the prompts and completions.
+        prompt = dataset[i][0]
+        prompt_token_ids = tokenizer(prompt).input_ids
+        completion = dataset[i][1]
+        completion_token_ids = tokenizer(completion).input_ids
         prompt_len = len(prompt_token_ids)
+        output_len = len(completion_token_ids
+                         ) if fixed_output_len is None else fixed_output_len
         if prompt_len < 4 or output_len < 4:
             # Prune too short sequences.
-            # This is because TGI causes errors when the input or output length
-            # is too short.
             continue
         if prompt_len > 1024 or prompt_len + output_len > 2048:
             # Prune too long sequences.
             continue
         filtered_dataset.append((prompt, prompt_len, output_len))
 
-    # Sample the requests.
-    sampled_requests = random.sample(filtered_dataset, num_requests)
-    return sampled_requests
+    return filtered_dataset
 
 
 def sample_sonnet_requests(
@@ -361,6 +359,7 @@ def main(args: argparse.Namespace):
             dataset_path=args.dataset,
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
+            fixed_output_len=args.sharegpt_output_len,
         )
 
     elif args.dataset_name == "sharegpt":
@@ -368,6 +367,7 @@
             dataset_path=args.dataset_path,
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
+            fixed_output_len=args.sharegpt_output_len,
        )

     elif args.dataset_name == "sonnet":
@@ -524,6 +524,12 @@ if __name__ == "__main__":
         default=1000,
         help="Number of prompts to process.",
     )
+    parser.add_argument(
+        "--sharegpt-output-len",
+        type=int,
+        default=None,
+        help="Output length for each request. Overrides the output length "
+        "from the ShareGPT dataset.")
     parser.add_argument(
         "--sonnet-input-len",
         type=int,
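For context, a minimal sketch of how the new parameter could be exercised directly, assuming it is run from the benchmarks/ directory; the tokenizer name and dataset path below are placeholder assumptions, not values taken from the patch:

# Hypothetical usage sketch for the new fixed_output_len parameter; the
# tokenizer name and dataset path are assumptions, not part of this patch.
from transformers import AutoTokenizer

from benchmark_serving import sample_sharegpt_requests

tokenizer = AutoTokenizer.from_pretrained("gpt2")
requests = sample_sharegpt_requests(
    dataset_path="sharegpt.json",  # assumed path to a ShareGPT-format dump
    num_requests=100,
    tokenizer=tokenizer,
    fixed_output_len=128,  # pins output_len; prompt-length filters still apply
)
# Each returned tuple is (prompt, prompt_len, output_len), with
# output_len == 128 for every request.

The same value flows in from the CLI via the new --sharegpt-output-len flag, which both sample_sharegpt_requests call sites in main() forward as fixed_output_len. Note that the shuffle-then-filter rewrite also changes the exhaustion behavior: when the dataset runs out of qualifying pairs, the function now returns fewer than num_requests items rather than raising ValueError from random.sample.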