[Bugfix] Generate exactly input_len tokens in benchmark_throughput (#9592)
This commit is contained in:
parent 208cb34c81
commit 65050a40e6
@ -233,7 +233,16 @@ def main(args: argparse.Namespace):
|
|||||||
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||||
if args.dataset is None:
|
if args.dataset is None:
|
||||||
# Synthesize a prompt with the given input length.
|
# Synthesize a prompt with the given input length.
|
||||||
prompt = "hi" * (args.input_len - 1)
|
# As tokenizer may add additional tokens like BOS, we need to try
|
||||||
|
# different lengths to get the desired input length.
|
||||||
|
for i in range(-10, 10):
|
||||||
|
prompt = "hi " * (args.input_len + i)
|
||||||
|
tokenized_prompt = tokenizer(prompt).input_ids
|
||||||
|
if len(tokenized_prompt) == args.input_len:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Failed to synthesize a prompt with {args.input_len} tokens.")
|
||||||
requests = [(prompt, args.input_len, args.output_len)
|
requests = [(prompt, args.input_len, args.output_len)
|
||||||
for _ in range(args.num_prompts)]
|
for _ in range(args.num_prompts)]
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user