diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 6bb889d1..695d06e7 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -103,25 +103,22 @@ def run_vllm(
     )
 
     # Add the requests to the engine.
+    prompts = []
+    sampling_params = []
     for prompt, _, output_len in requests:
-        sampling_params = SamplingParams(
-            n=n,
-            temperature=0.0 if use_beam_search else 1.0,
-            top_p=1.0,
-            use_beam_search=use_beam_search,
-            ignore_eos=True,
-            max_tokens=output_len,
-        )
-        # FIXME(woosuk): Do not use internal method.
-        llm._add_request(
-            prompt=prompt,
-            prompt_token_ids=None,
-            sampling_params=sampling_params,
-        )
+        prompts.append(prompt)
+        sampling_params.append(
+            SamplingParams(
+                n=n,
+                temperature=0.0 if use_beam_search else 1.0,
+                top_p=1.0,
+                use_beam_search=use_beam_search,
+                ignore_eos=True,
+                max_tokens=output_len,
+            ))
 
     start = time.perf_counter()
-    # FIXME(woosuk): Do not use internal method.
-    llm._run_engine(use_tqdm=True)
+    llm.generate(prompts, sampling_params, use_tqdm=True)
     end = time.perf_counter()
     return end - start
 
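The diff drops the FIXME'd internal calls (llm._add_request and llm._run_engine) in favor of the public batched API: prompts and per-request SamplingParams are collected into parallel lists and submitted through a single llm.generate call, which is the only thing the perf_counter pair times. Below is a minimal standalone sketch of the same pattern, assuming a vLLM version whose LLM.generate accepts a list of SamplingParams (as this diff relies on); the model name and prompts are illustrative, not taken from the benchmark.

from vllm import LLM, SamplingParams

# Illustrative checkpoint; any vLLM-supported model works here.
llm = LLM(model="facebook/opt-125m")

prompts = ["Hello, my name is", "The capital of France is"]
# One SamplingParams per prompt, mirroring the benchmark's per-request
# output lengths: ignore_eos plus max_tokens forces a fixed number of
# generated tokens for each request.
sampling_params = [
    SamplingParams(n=1, temperature=1.0, top_p=1.0,
                   ignore_eos=True, max_tokens=16)
    for _ in prompts
]

# Single batched call; the engine schedules all requests together.
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
for output in outputs:
    print(output.outputs[0].text)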