[Bugfix] Generate exactly input_len tokens in benchmark_throughput (#9592)

This commit is contained in:
Chen Zhang 2024-10-22 17:45:35 -07:00 committed by GitHub
parent 208cb34c81
commit 65050a40e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -233,7 +233,16 @@ def main(args: argparse.Namespace):
args.tokenizer, trust_remote_code=args.trust_remote_code) args.tokenizer, trust_remote_code=args.trust_remote_code)
if args.dataset is None: if args.dataset is None:
# Synthesize a prompt with the given input length. # Synthesize a prompt with the given input length.
prompt = "hi" * (args.input_len - 1) # As tokenizer may add additional tokens like BOS, we need to try
# different lengths to get the desired input length.
for i in range(-10, 10):
prompt = "hi " * (args.input_len + i)
tokenized_prompt = tokenizer(prompt).input_ids
if len(tokenized_prompt) == args.input_len:
break
else:
raise ValueError(
f"Failed to synthesize a prompt with {args.input_len} tokens.")
requests = [(prompt, args.input_len, args.output_len) requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)] for _ in range(args.num_prompts)]
else: else: