[Benchmark] Support sample from HF datasets and image input for benchmark_serving (#8495)
This commit is contained in:
parent
cbdb252259
commit
1b6de8352b
@ -25,6 +25,7 @@ class RequestFuncInput:
|
||||
best_of: int = 1
|
||||
use_beam_search: bool = False
|
||||
logprobs: Optional[int] = None
|
||||
multi_modal_content: Optional[dict] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -312,12 +313,15 @@ async def async_request_openai_chat_completions(
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert not request_func_input.use_beam_search
|
||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||
if request_func_input.multi_modal_content:
|
||||
content.append(request_func_input.multi_modal_content)
|
||||
payload = {
|
||||
"model": request_func_input.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": request_func_input.prompt,
|
||||
"content": content
|
||||
},
|
||||
],
|
||||
"temperature": 0.0,
|
||||
|
@ -24,6 +24,8 @@ On the client side, run:
|
||||
"""
|
||||
import argparse
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
@ -31,11 +33,13 @@ import time
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
|
||||
RequestFuncOutput)
|
||||
from datasets import load_dataset
|
||||
from PIL.Image import Image
|
||||
from tqdm.asyncio import tqdm
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
@ -84,7 +88,7 @@ def sample_sharegpt_requests(
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int] = None,
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
) -> List[Tuple[str, int, int, None]]:
|
||||
if fixed_output_len is not None and fixed_output_len < 4:
|
||||
raise ValueError("output_len too small")
|
||||
# Load the dataset.
|
||||
@ -119,7 +123,7 @@ def sample_sharegpt_requests(
|
||||
if prompt_len > 1024 or prompt_len + output_len > 2048:
|
||||
# Prune too long sequences.
|
||||
continue
|
||||
filtered_dataset.append((prompt, prompt_len, output_len))
|
||||
filtered_dataset.append((prompt, prompt_len, output_len, None))
|
||||
|
||||
return filtered_dataset
|
||||
|
||||
@ -131,7 +135,7 @@ def sample_sonnet_requests(
|
||||
output_len: int,
|
||||
prefix_len: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> List[Tuple[str, str, int, int]]:
|
||||
) -> List[Tuple[str, str, int, int, None]]:
|
||||
assert (
|
||||
input_len > prefix_len
|
||||
), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."
|
||||
@ -189,7 +193,65 @@ def sample_sonnet_requests(
|
||||
message, add_generation_prompt=True, tokenize=False)
|
||||
prompt_len = len(tokenizer(prompt_formatted).input_ids)
|
||||
sampled_requests.append(
|
||||
(prompt, prompt_formatted, prompt_len, output_len))
|
||||
(prompt, prompt_formatted, prompt_len, output_len, None))
|
||||
|
||||
return sampled_requests
|
||||
|
||||
|
||||
def sample_hf_requests(
|
||||
dataset_path: str,
|
||||
dataset_subset: str,
|
||||
dataset_split: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int] = None,
|
||||
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
|
||||
dataset = load_dataset(dataset_path,
|
||||
name=dataset_subset,
|
||||
split=dataset_split,
|
||||
streaming=True)
|
||||
assert "conversations" in dataset.features, (
|
||||
"HF Dataset must have 'conversations' column.")
|
||||
filtered_dataset = dataset.shuffle().filter(
|
||||
lambda x: len(x["conversations"]) >= 2)
|
||||
sampled_requests: List[Tuple[str, int, int, Dict[str,
|
||||
Collection[str]]]] = []
|
||||
for data in filtered_dataset:
|
||||
if len(sampled_requests) == num_requests:
|
||||
break
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompt = data["conversations"][0]["value"]
|
||||
prompt_token_ids = tokenizer(prompt).input_ids
|
||||
completion = data["conversations"][1]["value"]
|
||||
completion_token_ids = tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_token_ids)
|
||||
output_len = len(completion_token_ids
|
||||
) if fixed_output_len is None else fixed_output_len
|
||||
if prompt_len < 4 or output_len < 4:
|
||||
# Prune too short sequences.
|
||||
continue
|
||||
if prompt_len > 1024 or prompt_len + output_len > 2048:
|
||||
# Prune too long sequences.
|
||||
continue
|
||||
|
||||
if "image" in data and isinstance(data["image"], Image):
|
||||
image: Image = data["image"]
|
||||
image = image.convert("RGB")
|
||||
image_data = io.BytesIO()
|
||||
image.save(image_data, format='JPEG')
|
||||
image_base64 = base64.b64encode(
|
||||
image_data.getvalue()).decode("utf-8")
|
||||
mm_content = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||
},
|
||||
}
|
||||
else:
|
||||
mm_content = None
|
||||
|
||||
sampled_requests.append((prompt, prompt_len, output_len, mm_content))
|
||||
|
||||
return sampled_requests
|
||||
|
||||
@ -223,8 +285,8 @@ def sample_random_requests(
|
||||
[(offsets[i] + i + j) % tokenizer.vocab_size
|
||||
for j in range(input_lens[i])])
|
||||
|
||||
input_requests.append(
|
||||
(prompt, int(prefix_len + input_lens[i]), int(output_lens[i])))
|
||||
input_requests.append((prompt, int(prefix_len + input_lens[i]),
|
||||
int(output_lens[i]), None))
|
||||
|
||||
return input_requests
|
||||
|
||||
@ -343,7 +405,12 @@ async def benchmark(
|
||||
raise ValueError(f"Unknown backend: {backend}")
|
||||
|
||||
print("Starting initial single prompt test run...")
|
||||
test_prompt, test_prompt_len, test_output_len = input_requests[0]
|
||||
test_prompt, test_prompt_len, test_output_len, test_mm_content = (
|
||||
input_requests[0])
|
||||
if backend != "openai-chat" and test_mm_content is not None:
|
||||
# multi-modal benchmark is only available on OpenAI Chat backend.
|
||||
raise ValueError(
|
||||
"Multi-modal content is only supported on 'openai-chat' backend.")
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=test_prompt,
|
||||
@ -353,6 +420,7 @@ async def benchmark(
|
||||
logprobs=logprobs,
|
||||
best_of=best_of,
|
||||
use_beam_search=use_beam_search,
|
||||
multi_modal_content=test_mm_content,
|
||||
)
|
||||
test_output = await request_func(request_func_input=test_input)
|
||||
if not test_output.success:
|
||||
@ -373,6 +441,7 @@ async def benchmark(
|
||||
logprobs=logprobs,
|
||||
best_of=best_of,
|
||||
use_beam_search=use_beam_search,
|
||||
multi_modal_content=test_mm_content,
|
||||
)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
if profile_output.success:
|
||||
@ -385,7 +454,7 @@ async def benchmark(
|
||||
benchmark_start_time = time.perf_counter()
|
||||
tasks: List[asyncio.Task] = []
|
||||
async for request in get_request(input_requests, request_rate):
|
||||
prompt, prompt_len, output_len = request
|
||||
prompt, prompt_len, output_len, mm_content = request
|
||||
request_func_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=prompt,
|
||||
@ -395,6 +464,7 @@ async def benchmark(
|
||||
logprobs=logprobs,
|
||||
best_of=best_of,
|
||||
use_beam_search=use_beam_search,
|
||||
multi_modal_content=mm_content,
|
||||
)
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
@ -575,6 +645,16 @@ def main(args: argparse.Namespace):
|
||||
for prompt, prompt_formatted, prompt_len,
|
||||
output_len in input_requests]
|
||||
|
||||
elif args.dataset_name == "hf":
|
||||
input_requests = sample_hf_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
dataset_subset=args.hf_subset,
|
||||
dataset_split=args.hf_split,
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
fixed_output_len=args.hf_output_len,
|
||||
)
|
||||
|
||||
elif args.dataset_name == "random":
|
||||
input_requests = sample_random_requests(
|
||||
prefix_len=args.random_prefix_len,
|
||||
@ -685,13 +765,14 @@ if __name__ == "__main__":
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default="sharegpt",
|
||||
choices=["sharegpt", "sonnet", "random"],
|
||||
choices=["sharegpt", "sonnet", "random", "hf"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
)
|
||||
parser.add_argument("--dataset-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset.")
|
||||
help="Path to the sharegpt/sonnet dataset. "
|
||||
"Or the huggingface dataset ID if using HF dataset.")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
@ -718,26 +799,6 @@ if __name__ == "__main__":
|
||||
default=1000,
|
||||
help="Number of prompts to process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sharegpt-output-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Output length for each request. Overrides the output length "
|
||||
"from the ShareGPT dataset.")
|
||||
parser.add_argument(
|
||||
"--sonnet-input-len",
|
||||
type=int,
|
||||
default=550,
|
||||
help=
|
||||
"Number of input tokens per request, used only for sonnet dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sonnet-output-len",
|
||||
type=int,
|
||||
default=150,
|
||||
help=
|
||||
"Number of output tokens per request, used only for sonnet dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logprobs",
|
||||
type=int,
|
||||
@ -748,42 +809,6 @@ if __name__ == "__main__":
|
||||
"logprob is returned for each token; or (2) if beam search "
|
||||
"is enabled 1 logprob per token is computed"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sonnet-prefix-len",
|
||||
type=int,
|
||||
default=200,
|
||||
help=
|
||||
"Number of prefix tokens per request, used only for sonnet dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random-input-len",
|
||||
type=int,
|
||||
default=1024,
|
||||
help=
|
||||
"Number of input tokens per request, used only for random sampling.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random-output-len",
|
||||
type=int,
|
||||
default=128,
|
||||
help=
|
||||
"Number of output tokens per request, used only for random sampling.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random-range-ratio",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Range of sampled ratio of input/output length, "
|
||||
"used only for random sampling.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random-prefix-len",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of fixed prefix tokens before random "
|
||||
" context. The length range of context in a random "
|
||||
" request is [random-prefix-len, "
|
||||
" random-prefix-len + random-prefix-len * random-range-ratio).")
|
||||
parser.add_argument(
|
||||
"--request-rate",
|
||||
type=float,
|
||||
@ -857,5 +882,85 @@ if __name__ == "__main__":
|
||||
"Use \"--percentile-metrics\" to select metrics.",
|
||||
)
|
||||
|
||||
# group for dataset specific arguments
|
||||
sonnet_group = parser.add_argument_group("sonnet dataset options")
|
||||
sonnet_group.add_argument(
|
||||
"--sonnet-input-len",
|
||||
type=int,
|
||||
default=550,
|
||||
help=
|
||||
"Number of input tokens per request, used only for sonnet dataset.",
|
||||
)
|
||||
sonnet_group.add_argument(
|
||||
"--sonnet-output-len",
|
||||
type=int,
|
||||
default=150,
|
||||
help=
|
||||
"Number of output tokens per request, used only for sonnet dataset.",
|
||||
)
|
||||
sonnet_group.add_argument(
|
||||
"--sonnet-prefix-len",
|
||||
type=int,
|
||||
default=200,
|
||||
help=
|
||||
"Number of prefix tokens per request, used only for sonnet dataset.",
|
||||
)
|
||||
|
||||
sharegpt_group = parser.add_argument_group("sharegpt dataset options")
|
||||
sharegpt_group.add_argument(
|
||||
"--sharegpt-output-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Output length for each request. Overrides the output length "
|
||||
"from the ShareGPT dataset.")
|
||||
|
||||
random_group = parser.add_argument_group("random dataset options")
|
||||
random_group.add_argument(
|
||||
"--random-input-len",
|
||||
type=int,
|
||||
default=1024,
|
||||
help=
|
||||
"Number of input tokens per request, used only for random sampling.",
|
||||
)
|
||||
random_group.add_argument(
|
||||
"--random-output-len",
|
||||
type=int,
|
||||
default=128,
|
||||
help=
|
||||
"Number of output tokens per request, used only for random sampling.",
|
||||
)
|
||||
random_group.add_argument(
|
||||
"--random-range-ratio",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Range of sampled ratio of input/output length, "
|
||||
"used only for random sampling.",
|
||||
)
|
||||
random_group.add_argument(
|
||||
"--random-prefix-len",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of fixed prefix tokens before random "
|
||||
" context. The length range of context in a random "
|
||||
" request is [random-prefix-len, "
|
||||
" random-prefix-len + random-prefix-len * random-range-ratio).")
|
||||
|
||||
hf_group = parser.add_argument_group("hf dataset options")
|
||||
hf_group.add_argument("--hf-subset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Subset of the HF dataset.")
|
||||
hf_group.add_argument("--hf-split",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Split of the HF dataset.")
|
||||
hf_group.add_argument(
|
||||
"--hf-output-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Output length for each request. Overrides the output lengths "
|
||||
"from the sampled HF dataset.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
Loading…
x
Reference in New Issue
Block a user