[misc] partial prefix & random input generation benchmark (#9929)
Signed-off-by: rickyx <rickyx@anyscale.com>
This commit is contained in:
parent
2298e69b5f
commit
90a6c759ca
@ -54,13 +54,30 @@ def test_prefix(llm=None, sampling_params=None, prompts=None):
|
||||
print(f"cost time {end_time - start_time}")
|
||||
|
||||
|
||||
def sample_requests(
|
||||
@dataclasses.dataclass
|
||||
class Request:
|
||||
prompt: str
|
||||
prompt_len: int
|
||||
output_len: int
|
||||
|
||||
|
||||
def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str:
|
||||
vocab = tokenizer.get_vocab()
|
||||
# Remove the special tokens.
|
||||
vocab = {
|
||||
k: v
|
||||
for k, v in vocab.items() if k not in tokenizer.all_special_ids
|
||||
}
|
||||
return random.choices(list(vocab.values()), k=length)
|
||||
|
||||
|
||||
def sample_requests_from_dataset(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
input_length_range: Tuple[int, int],
|
||||
fixed_output_len: Optional[int],
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
) -> List[Request]:
|
||||
if fixed_output_len is not None and fixed_output_len < 4:
|
||||
raise ValueError("output_len too small")
|
||||
|
||||
@ -77,31 +94,55 @@ def sample_requests(
|
||||
random.shuffle(dataset)
|
||||
|
||||
min_len, max_len = input_length_range
|
||||
assert min_len >= 0 and max_len >= min_len, "input_length_range too small"
|
||||
|
||||
# Filter out sequences that are too long or too short
|
||||
filtered_dataset: List[Tuple[str, int, int]] = []
|
||||
filtered_requests: List[Request] = []
|
||||
|
||||
for i in range(len(dataset)):
|
||||
if len(filtered_dataset) == num_requests:
|
||||
if len(filtered_requests) == num_requests:
|
||||
break
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompt = dataset[i][0]
|
||||
prompt_token_ids = tokenizer(prompt).input_ids
|
||||
prompt_token_ids = tokenizer(dataset[i][0]).input_ids
|
||||
prompt = tokenizer.decode(prompt_token_ids)
|
||||
completion = dataset[i][1]
|
||||
completion_token_ids = tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_token_ids)
|
||||
output_len = len(completion_token_ids
|
||||
) if fixed_output_len is None else fixed_output_len
|
||||
if prompt_len < 4 or output_len < 4:
|
||||
# Prune too short sequences.
|
||||
continue
|
||||
output_len = (len(completion_token_ids)
|
||||
if fixed_output_len is None else fixed_output_len)
|
||||
if min_len <= prompt_len <= max_len:
|
||||
filtered_dataset.append((prompt, prompt_len, output_len))
|
||||
filtered_requests.append(Request(prompt, prompt_len, output_len))
|
||||
|
||||
return filtered_dataset
|
||||
return filtered_requests
|
||||
|
||||
|
||||
def repeat_and_sort_requests(requests: List[Tuple[str, int, int]],
|
||||
def sample_requests_from_random(
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
input_length_range: Tuple[int, int],
|
||||
fixed_output_len: Optional[int],
|
||||
prefix_len: int,
|
||||
) -> List[Request]:
|
||||
|
||||
requests = []
|
||||
prefix_token_ids = sample_tokens(tokenizer, prefix_len)
|
||||
min_len, max_len = input_length_range
|
||||
|
||||
for i in range(num_requests):
|
||||
unique_part_token_ids = sample_tokens(
|
||||
tokenizer,
|
||||
random.randint(min_len - prefix_len, max_len - prefix_len))
|
||||
prompt_token_ids = prefix_token_ids + unique_part_token_ids
|
||||
prompt = tokenizer.decode(prompt_token_ids)
|
||||
prompt_len = len(prompt_token_ids)
|
||||
assert (min_len <= prompt_len <= max_len
|
||||
), f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
|
||||
requests.append(Request(prompt, prompt_len, fixed_output_len))
|
||||
return requests
|
||||
|
||||
|
||||
def repeat_and_sort_requests(requests: List[Request],
|
||||
repeat_count: int,
|
||||
sort: bool = False) -> List[str]:
|
||||
repeated_requests = requests * repeat_count
|
||||
@ -109,7 +150,7 @@ def repeat_and_sort_requests(requests: List[Tuple[str, int, int]],
|
||||
repeated_requests.sort(key=lambda x: x[1])
|
||||
else:
|
||||
random.shuffle(repeated_requests)
|
||||
return [req[0] for req in repeated_requests]
|
||||
return [req.prompt for req in repeated_requests]
|
||||
|
||||
|
||||
def main(args):
|
||||
@ -117,9 +158,12 @@ def main(args):
|
||||
input_length_range = tuple(map(int, args.input_length_range.split(':')))
|
||||
random.seed(args.seed)
|
||||
if args.dataset_path is not None:
|
||||
print(f"Start to sample {args.num_prompts} prompts"
|
||||
if args.prefix_len > 0:
|
||||
raise ValueError("prefix-len is not supported when "
|
||||
"dataset-path is provided.")
|
||||
print(f"Start to sample {args.num_prompts} prompts "
|
||||
f"from {args.dataset_path}")
|
||||
filtered_datasets = sample_requests(
|
||||
filtered_requests = sample_requests_from_dataset(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
@ -127,9 +171,22 @@ def main(args):
|
||||
fixed_output_len=args.output_len,
|
||||
)
|
||||
else:
|
||||
prompt_len = len(tokenizer(PROMPT).input_ids)
|
||||
filtered_datasets = [(PROMPT, prompt_len, args.output_len)
|
||||
] * args.num_prompts
|
||||
print(f"Start to sample {args.num_prompts} prompts from random")
|
||||
filtered_requests = sample_requests_from_random(
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
input_length_range=input_length_range,
|
||||
fixed_output_len=args.output_len,
|
||||
prefix_len=args.prefix_len,
|
||||
)
|
||||
|
||||
# Print some helpful stats of the requests.
|
||||
print(f"Sampled {len(filtered_requests)} requests.")
|
||||
prompt_lens = [req.prompt_len for req in filtered_requests]
|
||||
print(f"Average input length: {sum(prompt_lens) / len(prompt_lens)}")
|
||||
print(f"P50 input length: {sorted(prompt_lens)[len(prompt_lens) // 2]}")
|
||||
print(f"Min Prompt Length: {min(prompt_lens)}")
|
||||
print(f"Max Prompt Length: {max(prompt_lens)}")
|
||||
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
|
||||
@ -137,8 +194,8 @@ def main(args):
|
||||
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
|
||||
|
||||
print("Testing filtered datasets")
|
||||
prompts = repeat_and_sort_requests(filtered_datasets,
|
||||
print("Testing filtered requests")
|
||||
prompts = repeat_and_sort_requests(filtered_requests,
|
||||
repeat_count=args.repeat_count,
|
||||
sort=args.sort)
|
||||
|
||||
@ -161,20 +218,29 @@ if __name__ == "__main__":
|
||||
parser.add_argument('--output-len', type=int, default=10)
|
||||
parser.add_argument('--num-prompts',
|
||||
type=int,
|
||||
default=1,
|
||||
required=True,
|
||||
help="Number of the prompts sampled from dataset")
|
||||
parser.add_argument('--repeat-count',
|
||||
type=int,
|
||||
default=100,
|
||||
default=1,
|
||||
help='Number of times to repeat each prompt')
|
||||
parser.add_argument('--sort',
|
||||
action='store_true',
|
||||
help='Sort prompts by input length')
|
||||
parser.add_argument('--input-length-range',
|
||||
type=str,
|
||||
default='128:256',
|
||||
required=True,
|
||||
help='Range of input lengths for sampling prompts,'
|
||||
'specified as "min:max" (e.g., "128:256").')
|
||||
parser.add_argument(
|
||||
"--prefix-len",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Specifies the length of a common prefix to be "
|
||||
"added to the input prompt. The input-length-range will "
|
||||
"subtract this length when filtering prompts. Only used "
|
||||
"when dataset-path is not provided.",
|
||||
)
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
Loading…
x
Reference in New Issue
Block a user