pass ignore_eos parameter to all benchmark_serving calls (#9349)

This commit is contained in:
Grace Ho 2024-10-15 13:30:44 -07:00 committed by GitHub
parent e9d517f276
commit 5d264f4ab8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -431,8 +431,7 @@ async def benchmark(
if profile: if profile:
print("Starting profiler...") print("Starting profiler...")
profile_input = RequestFuncInput( profile_input = RequestFuncInput(model=model_id,
model=model_id,
prompt=test_prompt, prompt=test_prompt,
api_url=base_url + "/start_profile", api_url=base_url + "/start_profile",
prompt_len=test_prompt_len, prompt_len=test_prompt_len,
@ -440,7 +439,7 @@ async def benchmark(
logprobs=logprobs, logprobs=logprobs,
best_of=best_of, best_of=best_of,
multi_modal_content=test_mm_content, multi_modal_content=test_mm_content,
) ignore_eos=ignore_eos)
profile_output = await request_func(request_func_input=profile_input) profile_output = await request_func(request_func_input=profile_input)
if profile_output.success: if profile_output.success:
print("Profiler started") print("Profiler started")
@ -453,8 +452,7 @@ async def benchmark(
tasks: List[asyncio.Task] = [] tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, request_rate): async for request in get_request(input_requests, request_rate):
prompt, prompt_len, output_len, mm_content = request prompt, prompt_len, output_len, mm_content = request
request_func_input = RequestFuncInput( request_func_input = RequestFuncInput(model=model_id,
model=model_id,
prompt=prompt, prompt=prompt,
api_url=api_url, api_url=api_url,
prompt_len=prompt_len, prompt_len=prompt_len,
@ -462,7 +460,7 @@ async def benchmark(
logprobs=logprobs, logprobs=logprobs,
best_of=best_of, best_of=best_of,
multi_modal_content=mm_content, multi_modal_content=mm_content,
) ignore_eos=ignore_eos)
tasks.append( tasks.append(
asyncio.create_task( asyncio.create_task(
request_func(request_func_input=request_func_input, request_func(request_func_input=request_func_input,