[Misc] Use public API in benchmark_throughput (#4300)

This commit is contained in:
zifeitong 2024-04-24 14:10:24 -07:00 committed by GitHub
parent 2768884ac4
commit a395a638c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -103,25 +103,22 @@ def run_vllm(
) )
# Add the requests to the engine. # Add the requests to the engine.
prompts = []
sampling_params = []
for prompt, _, output_len in requests: for prompt, _, output_len in requests:
sampling_params = SamplingParams( prompts.append(prompt)
sampling_params.append(
SamplingParams(
n=n, n=n,
temperature=0.0 if use_beam_search else 1.0, temperature=0.0 if use_beam_search else 1.0,
top_p=1.0, top_p=1.0,
use_beam_search=use_beam_search, use_beam_search=use_beam_search,
ignore_eos=True, ignore_eos=True,
max_tokens=output_len, max_tokens=output_len,
) ))
# FIXME(woosuk): Do not use internal method.
llm._add_request(
prompt=prompt,
prompt_token_ids=None,
sampling_params=sampling_params,
)
start = time.perf_counter() start = time.perf_counter()
# FIXME(woosuk): Do not use internal method. llm.generate(prompts, sampling_params, use_tqdm=True)
llm._run_engine(use_tqdm=True)
end = time.perf_counter() end = time.perf_counter()
return end - start return end - start