[Benchmark] Support sample from HF datasets and image input for benchmark_serving (#8495)

This commit is contained in:
Isotr0py 2024-09-17 15:34:27 +08:00 committed by GitHub
parent cbdb252259
commit 1b6de8352b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 177 additions and 68 deletions

View File

@ -25,6 +25,7 @@ class RequestFuncInput:
best_of: int = 1
use_beam_search: bool = False
logprobs: Optional[int] = None
multi_modal_content: Optional[dict] = None
@dataclass
@ -312,12 +313,15 @@ async def async_request_openai_chat_completions(
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content)
payload = {
"model": request_func_input.model,
"messages": [
{
"role": "user",
"content": request_func_input.prompt,
"content": content
},
],
"temperature": 0.0,

View File

@ -24,6 +24,8 @@ On the client side, run:
"""
import argparse
import asyncio
import base64
import io
import json
import os
import random
@ -31,11 +33,13 @@ import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple
import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
RequestFuncOutput)
from datasets import load_dataset
from PIL.Image import Image
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
@ -84,7 +88,7 @@ def sample_sharegpt_requests(
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]:
) -> List[Tuple[str, int, int, None]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset.
@ -119,7 +123,7 @@ def sample_sharegpt_requests(
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
filtered_dataset.append((prompt, prompt_len, output_len, None))
return filtered_dataset
@ -131,7 +135,7 @@ def sample_sonnet_requests(
output_len: int,
prefix_len: int,
tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, str, int, int]]:
) -> List[Tuple[str, str, int, int, None]]:
assert (
input_len > prefix_len
), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."
@ -189,7 +193,65 @@ def sample_sonnet_requests(
message, add_generation_prompt=True, tokenize=False)
prompt_len = len(tokenizer(prompt_formatted).input_ids)
sampled_requests.append(
(prompt, prompt_formatted, prompt_len, output_len))
(prompt, prompt_formatted, prompt_len, output_len, None))
return sampled_requests
def sample_hf_requests(
dataset_path: str,
dataset_subset: str,
dataset_split: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
dataset = load_dataset(dataset_path,
name=dataset_subset,
split=dataset_split,
streaming=True)
assert "conversations" in dataset.features, (
"HF Dataset must have 'conversations' column.")
filtered_dataset = dataset.shuffle().filter(
lambda x: len(x["conversations"]) >= 2)
sampled_requests: List[Tuple[str, int, int, Dict[str,
Collection[str]]]] = []
for data in filtered_dataset:
if len(sampled_requests) == num_requests:
break
# Tokenize the prompts and completions.
prompt = data["conversations"][0]["value"]
prompt_token_ids = tokenizer(prompt).input_ids
completion = data["conversations"][1]["value"]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
if "image" in data and isinstance(data["image"], Image):
image: Image = data["image"]
image = image.convert("RGB")
image_data = io.BytesIO()
image.save(image_data, format='JPEG')
image_base64 = base64.b64encode(
image_data.getvalue()).decode("utf-8")
mm_content = {
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
},
}
else:
mm_content = None
sampled_requests.append((prompt, prompt_len, output_len, mm_content))
return sampled_requests
@ -223,8 +285,8 @@ def sample_random_requests(
[(offsets[i] + i + j) % tokenizer.vocab_size
for j in range(input_lens[i])])
input_requests.append(
(prompt, int(prefix_len + input_lens[i]), int(output_lens[i])))
input_requests.append((prompt, int(prefix_len + input_lens[i]),
int(output_lens[i]), None))
return input_requests
@ -343,7 +405,12 @@ async def benchmark(
raise ValueError(f"Unknown backend: {backend}")
print("Starting initial single prompt test run...")
test_prompt, test_prompt_len, test_output_len = input_requests[0]
test_prompt, test_prompt_len, test_output_len, test_mm_content = (
input_requests[0])
if backend != "openai-chat" and test_mm_content is not None:
# multi-modal benchmark is only available on OpenAI Chat backend.
raise ValueError(
"Multi-modal content is only supported on 'openai-chat' backend.")
test_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
@ -353,6 +420,7 @@ async def benchmark(
logprobs=logprobs,
best_of=best_of,
use_beam_search=use_beam_search,
multi_modal_content=test_mm_content,
)
test_output = await request_func(request_func_input=test_input)
if not test_output.success:
@ -373,6 +441,7 @@ async def benchmark(
logprobs=logprobs,
best_of=best_of,
use_beam_search=use_beam_search,
multi_modal_content=test_mm_content,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
@ -385,7 +454,7 @@ async def benchmark(
benchmark_start_time = time.perf_counter()
tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, request_rate):
prompt, prompt_len, output_len = request
prompt, prompt_len, output_len, mm_content = request
request_func_input = RequestFuncInput(
model=model_id,
prompt=prompt,
@ -395,6 +464,7 @@ async def benchmark(
logprobs=logprobs,
best_of=best_of,
use_beam_search=use_beam_search,
multi_modal_content=mm_content,
)
tasks.append(
asyncio.create_task(
@ -575,6 +645,16 @@ def main(args: argparse.Namespace):
for prompt, prompt_formatted, prompt_len,
output_len in input_requests]
elif args.dataset_name == "hf":
input_requests = sample_hf_requests(
dataset_path=args.dataset_path,
dataset_subset=args.hf_subset,
dataset_split=args.hf_split,
num_requests=args.num_prompts,
tokenizer=tokenizer,
fixed_output_len=args.hf_output_len,
)
elif args.dataset_name == "random":
input_requests = sample_random_requests(
prefix_len=args.random_prefix_len,
@ -685,13 +765,14 @@ if __name__ == "__main__":
"--dataset-name",
type=str,
default="sharegpt",
choices=["sharegpt", "sonnet", "random"],
choices=["sharegpt", "sonnet", "random", "hf"],
help="Name of the dataset to benchmark on.",
)
parser.add_argument("--dataset-path",
type=str,
default=None,
help="Path to the dataset.")
help="Path to the sharegpt/sonnet dataset. "
"Or the huggingface dataset ID if using HF dataset.")
parser.add_argument(
"--model",
type=str,
@ -718,26 +799,6 @@ if __name__ == "__main__":
default=1000,
help="Number of prompts to process.",
)
parser.add_argument(
"--sharegpt-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output length "
"from the ShareGPT dataset.")
parser.add_argument(
"--sonnet-input-len",
type=int,
default=550,
help=
"Number of input tokens per request, used only for sonnet dataset.",
)
parser.add_argument(
"--sonnet-output-len",
type=int,
default=150,
help=
"Number of output tokens per request, used only for sonnet dataset.",
)
parser.add_argument(
"--logprobs",
type=int,
@ -748,42 +809,6 @@ if __name__ == "__main__":
"logprob is returned for each token; or (2) if beam search "
"is enabled 1 logprob per token is computed"),
)
parser.add_argument(
"--sonnet-prefix-len",
type=int,
default=200,
help=
"Number of prefix tokens per request, used only for sonnet dataset.",
)
parser.add_argument(
"--random-input-len",
type=int,
default=1024,
help=
"Number of input tokens per request, used only for random sampling.",
)
parser.add_argument(
"--random-output-len",
type=int,
default=128,
help=
"Number of output tokens per request, used only for random sampling.",
)
parser.add_argument(
"--random-range-ratio",
type=float,
default=1.0,
help="Range of sampled ratio of input/output length, "
"used only for random sampling.",
)
parser.add_argument(
"--random-prefix-len",
type=int,
default=0,
help="Number of fixed prefix tokens before random "
" context. The length range of context in a random "
" request is [random-prefix-len, "
" random-prefix-len + random-prefix-len * random-range-ratio).")
parser.add_argument(
"--request-rate",
type=float,
@ -857,5 +882,85 @@ if __name__ == "__main__":
"Use \"--percentile-metrics\" to select metrics.",
)
# group for dataset specific arguments
sonnet_group = parser.add_argument_group("sonnet dataset options")
sonnet_group.add_argument(
"--sonnet-input-len",
type=int,
default=550,
help=
"Number of input tokens per request, used only for sonnet dataset.",
)
sonnet_group.add_argument(
"--sonnet-output-len",
type=int,
default=150,
help=
"Number of output tokens per request, used only for sonnet dataset.",
)
sonnet_group.add_argument(
"--sonnet-prefix-len",
type=int,
default=200,
help=
"Number of prefix tokens per request, used only for sonnet dataset.",
)
sharegpt_group = parser.add_argument_group("sharegpt dataset options")
sharegpt_group.add_argument(
"--sharegpt-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output length "
"from the ShareGPT dataset.")
random_group = parser.add_argument_group("random dataset options")
random_group.add_argument(
"--random-input-len",
type=int,
default=1024,
help=
"Number of input tokens per request, used only for random sampling.",
)
random_group.add_argument(
"--random-output-len",
type=int,
default=128,
help=
"Number of output tokens per request, used only for random sampling.",
)
random_group.add_argument(
"--random-range-ratio",
type=float,
default=1.0,
help="Range of sampled ratio of input/output length, "
"used only for random sampling.",
)
random_group.add_argument(
"--random-prefix-len",
type=int,
default=0,
help="Number of fixed prefix tokens before random "
" context. The length range of context in a random "
" request is [random-prefix-len, "
" random-prefix-len + random-prefix-len * random-range-ratio).")
hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset",
type=str,
default=None,
help="Subset of the HF dataset.")
hf_group.add_argument("--hf-split",
type=str,
default=None,
help="Split of the HF dataset.")
hf_group.add_argument(
"--hf-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output lengths "
"from the sampled HF dataset.",
)
args = parser.parse_args()
main(args)