[Misc] Use public API in benchmark_throughput (#4300)

This commit is contained in:
zifeitong 2024-04-24 14:10:24 -07:00 committed by GitHub
parent 2768884ac4
commit a395a638c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -103,25 +103,22 @@ def run_vllm(
)
# Add the requests to the engine.
prompts = []
sampling_params = []
for prompt, _, output_len in requests:
sampling_params = SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
)
# FIXME(woosuk): Do not use internal method.
llm._add_request(
prompt=prompt,
prompt_token_ids=None,
sampling_params=sampling_params,
)
prompts.append(prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
start = time.perf_counter()
# FIXME(woosuk): Do not use internal method.
llm._run_engine(use_tqdm=True)
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
return end - start