[Benchmark] Support sampling from HF datasets and image input for benchmark_serving (#8495)

parent cbdb252259
commit 1b6de8352b

benchmarks/backend_request_func.py

@@ -25,6 +25,7 @@ class RequestFuncInput:
     best_of: int = 1
     use_beam_search: bool = False
     logprobs: Optional[int] = None
+    multi_modal_content: Optional[dict] = None
 
 
 @dataclass
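Note: the new field defaults to None, so existing text-only callers are unaffected. A minimal usage sketch (the model name, URL, and base64 payload are placeholders, not part of this commit; the other fields are as defined elsewhere in backend_request_func.py):

    # Hypothetical text+image request; multi_modal_content follows the
    # OpenAI chat "image_url" content-part shape used later in this diff.
    request = RequestFuncInput(
        prompt="Describe this image.",
        api_url="http://localhost:8000/v1/chat/completions",  # placeholder
        prompt_len=5,
        output_len=128,
        model="my-model",  # placeholder
        multi_modal_content={
            "type": "image_url",
            "image_url": {"url": "data:image/jpeg;base64,<...>"},
        },
    )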
@@ -312,12 +313,15 @@ async def async_request_openai_chat_completions(
 
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         assert not request_func_input.use_beam_search
+        content = [{"type": "text", "text": request_func_input.prompt}]
+        if request_func_input.multi_modal_content:
+            content.append(request_func_input.multi_modal_content)
         payload = {
             "model": request_func_input.model,
             "messages": [
                 {
                     "role": "user",
-                    "content": request_func_input.prompt,
+                    "content": content
                 },
             ],
             "temperature": 0.0,
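With an image attached, the request body sent to the chat completions endpoint takes the standard OpenAI multimodal shape; an illustrative payload (values are placeholders):

    payload = {
        "model": "my-model",
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url",
                 "image_url": {"url": "data:image/jpeg;base64,<...>"}},
            ],
        }],
        "temperature": 0.0,
    }

Text-only requests go through the same code path: content then holds a single text part instead of a plain string, which the OpenAI chat API also accepts.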
benchmarks/benchmark_serving.py

@@ -24,6 +24,8 @@ On the client side, run:
 """
 import argparse
 import asyncio
+import base64
+import io
 import json
 import os
 import random
@@ -31,11 +33,13 @@ import time
 import warnings
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
+from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple
 
 import numpy as np
 from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
                                   RequestFuncOutput)
+from datasets import load_dataset
+from PIL.Image import Image
 from tqdm.asyncio import tqdm
 from transformers import PreTrainedTokenizerBase
 
@@ -84,7 +88,7 @@ def sample_sharegpt_requests(
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int] = None,
-) -> List[Tuple[str, int, int]]:
+) -> List[Tuple[str, int, int, None]]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
     # Load the dataset.
@@ -119,7 +123,7 @@ def sample_sharegpt_requests(
         if prompt_len > 1024 or prompt_len + output_len > 2048:
             # Prune too long sequences.
             continue
-        filtered_dataset.append((prompt, prompt_len, output_len))
+        filtered_dataset.append((prompt, prompt_len, output_len, None))
 
     return filtered_dataset
 
@@ -131,7 +135,7 @@ def sample_sonnet_requests(
     output_len: int,
     prefix_len: int,
     tokenizer: PreTrainedTokenizerBase,
-) -> List[Tuple[str, str, int, int]]:
+) -> List[Tuple[str, str, int, int, None]]:
     assert (
         input_len > prefix_len
     ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."
@@ -189,7 +193,65 @@ def sample_sonnet_requests(
             message, add_generation_prompt=True, tokenize=False)
         prompt_len = len(tokenizer(prompt_formatted).input_ids)
         sampled_requests.append(
-            (prompt, prompt_formatted, prompt_len, output_len))
+            (prompt, prompt_formatted, prompt_len, output_len, None))
+
+    return sampled_requests
+
+
+def sample_hf_requests(
+    dataset_path: str,
+    dataset_subset: str,
+    dataset_split: str,
+    num_requests: int,
+    tokenizer: PreTrainedTokenizerBase,
+    fixed_output_len: Optional[int] = None,
+) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
+    dataset = load_dataset(dataset_path,
+                           name=dataset_subset,
+                           split=dataset_split,
+                           streaming=True)
+    assert "conversations" in dataset.features, (
+        "HF Dataset must have 'conversations' column.")
+    filtered_dataset = dataset.shuffle().filter(
+        lambda x: len(x["conversations"]) >= 2)
+    sampled_requests: List[Tuple[str, int, int, Dict[str,
+                                                     Collection[str]]]] = []
+    for data in filtered_dataset:
+        if len(sampled_requests) == num_requests:
+            break
+
+        # Tokenize the prompts and completions.
+        prompt = data["conversations"][0]["value"]
+        prompt_token_ids = tokenizer(prompt).input_ids
+        completion = data["conversations"][1]["value"]
+        completion_token_ids = tokenizer(completion).input_ids
+        prompt_len = len(prompt_token_ids)
+        output_len = len(completion_token_ids
+                         ) if fixed_output_len is None else fixed_output_len
+        if prompt_len < 4 or output_len < 4:
+            # Prune too short sequences.
+            continue
+        if prompt_len > 1024 or prompt_len + output_len > 2048:
+            # Prune too long sequences.
+            continue
+
+        if "image" in data and isinstance(data["image"], Image):
+            image: Image = data["image"]
+            image = image.convert("RGB")
+            image_data = io.BytesIO()
+            image.save(image_data, format='JPEG')
+            image_base64 = base64.b64encode(
+                image_data.getvalue()).decode("utf-8")
+            mm_content = {
+                "type": "image_url",
+                "image_url": {
+                    "url": f"data:image/jpeg;base64,{image_base64}"
+                },
+            }
+        else:
+            mm_content = None
+
+        sampled_requests.append((prompt, prompt_len, output_len, mm_content))
+
     return sampled_requests
 
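A usage sketch for the new sampler (the dataset ID, tokenizer name, and split below are illustrative assumptions, not part of this commit):

    # Hypothetical call: sample 10 requests from a conversations-style HF
    # dataset. Rows with a PIL image in an "image" column come back with a
    # base64 JPEG data URL as mm_content; text-only rows come back with None.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("my-org/my-model")  # placeholder
    requests = sample_hf_requests(
        dataset_path="my-org/my-vision-dataset",  # placeholder HF dataset ID
        dataset_subset=None,
        dataset_split="train",
        num_requests=10,
        tokenizer=tokenizer,
    )

Because load_dataset is called with streaming=True, rows are pulled lazily rather than downloading the whole dataset up front.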
@@ -223,8 +285,8 @@ def sample_random_requests(
                                   [(offsets[i] + i + j) % tokenizer.vocab_size
                                    for j in range(input_lens[i])])
 
-        input_requests.append(
-            (prompt, int(prefix_len + input_lens[i]), int(output_lens[i])))
+        input_requests.append((prompt, int(prefix_len + input_lens[i]),
+                               int(output_lens[i]), None))
 
     return input_requests
 
@@ -343,7 +405,12 @@ async def benchmark(
         raise ValueError(f"Unknown backend: {backend}")
 
     print("Starting initial single prompt test run...")
-    test_prompt, test_prompt_len, test_output_len = input_requests[0]
+    test_prompt, test_prompt_len, test_output_len, test_mm_content = (
+        input_requests[0])
+    if backend != "openai-chat" and test_mm_content is not None:
+        # multi-modal benchmark is only available on OpenAI Chat backend.
+        raise ValueError(
+            "Multi-modal content is only supported on 'openai-chat' backend.")
     test_input = RequestFuncInput(
         model=model_id,
         prompt=test_prompt,
@@ -353,6 +420,7 @@ async def benchmark(
         logprobs=logprobs,
         best_of=best_of,
         use_beam_search=use_beam_search,
+        multi_modal_content=test_mm_content,
     )
     test_output = await request_func(request_func_input=test_input)
     if not test_output.success:
@@ -373,6 +441,7 @@ async def benchmark(
         logprobs=logprobs,
         best_of=best_of,
         use_beam_search=use_beam_search,
+        multi_modal_content=test_mm_content,
     )
     profile_output = await request_func(request_func_input=profile_input)
     if profile_output.success:
@@ -385,7 +454,7 @@ async def benchmark(
     benchmark_start_time = time.perf_counter()
     tasks: List[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate):
-        prompt, prompt_len, output_len = request
+        prompt, prompt_len, output_len, mm_content = request
         request_func_input = RequestFuncInput(
             model=model_id,
             prompt=prompt,
@@ -395,6 +464,7 @@ async def benchmark(
             logprobs=logprobs,
             best_of=best_of,
             use_beam_search=use_beam_search,
+            multi_modal_content=mm_content,
         )
         tasks.append(
             asyncio.create_task(
@@ -575,6 +645,16 @@ def main(args: argparse.Namespace):
                           for prompt, prompt_formatted, prompt_len,
                           output_len in input_requests]
 
+    elif args.dataset_name == "hf":
+        input_requests = sample_hf_requests(
+            dataset_path=args.dataset_path,
+            dataset_subset=args.hf_subset,
+            dataset_split=args.hf_split,
+            num_requests=args.num_prompts,
+            tokenizer=tokenizer,
+            fixed_output_len=args.hf_output_len,
+        )
+
     elif args.dataset_name == "random":
         input_requests = sample_random_requests(
             prefix_len=args.random_prefix_len,
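After this dispatch, every request tuple consumed by benchmark() carries four fields, so downstream code can unpack uniformly; a one-line illustration:

    # mm_content is None for text-only samples and an "image_url"
    # content part for image-bearing hf samples.
    prompt, prompt_len, output_len, mm_content = input_requests[0]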
@@ -685,13 +765,14 @@ if __name__ == "__main__":
         "--dataset-name",
         type=str,
         default="sharegpt",
-        choices=["sharegpt", "sonnet", "random"],
+        choices=["sharegpt", "sonnet", "random", "hf"],
         help="Name of the dataset to benchmark on.",
     )
     parser.add_argument("--dataset-path",
                         type=str,
                         default=None,
-                        help="Path to the dataset.")
+                        help="Path to the sharegpt/sonnet dataset. "
+                        "Or the huggingface dataset ID if using HF dataset.")
     parser.add_argument(
         "--model",
         type=str,
@@ -718,26 +799,6 @@ if __name__ == "__main__":
         default=1000,
         help="Number of prompts to process.",
     )
-    parser.add_argument(
-        "--sharegpt-output-len",
-        type=int,
-        default=None,
-        help="Output length for each request. Overrides the output length "
-        "from the ShareGPT dataset.")
-    parser.add_argument(
-        "--sonnet-input-len",
-        type=int,
-        default=550,
-        help=
-        "Number of input tokens per request, used only for sonnet dataset.",
-    )
-    parser.add_argument(
-        "--sonnet-output-len",
-        type=int,
-        default=150,
-        help=
-        "Number of output tokens per request, used only for sonnet dataset.",
-    )
     parser.add_argument(
         "--logprobs",
         type=int,
@@ -748,42 +809,6 @@ if __name__ == "__main__":
         "logprob is returned for each token; or (2) if beam search "
         "is enabled 1 logprob per token is computed"),
     )
-    parser.add_argument(
-        "--sonnet-prefix-len",
-        type=int,
-        default=200,
-        help=
-        "Number of prefix tokens per request, used only for sonnet dataset.",
-    )
-    parser.add_argument(
-        "--random-input-len",
-        type=int,
-        default=1024,
-        help=
-        "Number of input tokens per request, used only for random sampling.",
-    )
-    parser.add_argument(
-        "--random-output-len",
-        type=int,
-        default=128,
-        help=
-        "Number of output tokens per request, used only for random sampling.",
-    )
-    parser.add_argument(
-        "--random-range-ratio",
-        type=float,
-        default=1.0,
-        help="Range of sampled ratio of input/output length, "
-        "used only for random sampling.",
-    )
-    parser.add_argument(
-        "--random-prefix-len",
-        type=int,
-        default=0,
-        help="Number of fixed prefix tokens before random "
-        " context. The length range of context in a random "
-        " request is [random-prefix-len, "
-        " random-prefix-len + random-prefix-len * random-range-ratio).")
     parser.add_argument(
         "--request-rate",
         type=float,
@@ -857,5 +882,85 @@ if __name__ == "__main__":
         "Use \"--percentile-metrics\" to select metrics.",
     )
 
+    # group for dataset specific arguments
+    sonnet_group = parser.add_argument_group("sonnet dataset options")
+    sonnet_group.add_argument(
+        "--sonnet-input-len",
+        type=int,
+        default=550,
+        help=
+        "Number of input tokens per request, used only for sonnet dataset.",
+    )
+    sonnet_group.add_argument(
+        "--sonnet-output-len",
+        type=int,
+        default=150,
+        help=
+        "Number of output tokens per request, used only for sonnet dataset.",
+    )
+    sonnet_group.add_argument(
+        "--sonnet-prefix-len",
+        type=int,
+        default=200,
+        help=
+        "Number of prefix tokens per request, used only for sonnet dataset.",
+    )
+
+    sharegpt_group = parser.add_argument_group("sharegpt dataset options")
+    sharegpt_group.add_argument(
+        "--sharegpt-output-len",
+        type=int,
+        default=None,
+        help="Output length for each request. Overrides the output length "
+        "from the ShareGPT dataset.")
+
+    random_group = parser.add_argument_group("random dataset options")
+    random_group.add_argument(
+        "--random-input-len",
+        type=int,
+        default=1024,
+        help=
+        "Number of input tokens per request, used only for random sampling.",
+    )
+    random_group.add_argument(
+        "--random-output-len",
+        type=int,
+        default=128,
+        help=
+        "Number of output tokens per request, used only for random sampling.",
+    )
+    random_group.add_argument(
+        "--random-range-ratio",
+        type=float,
+        default=1.0,
+        help="Range of sampled ratio of input/output length, "
+        "used only for random sampling.",
+    )
+    random_group.add_argument(
+        "--random-prefix-len",
+        type=int,
+        default=0,
+        help="Number of fixed prefix tokens before random "
+        " context. The length range of context in a random "
+        " request is [random-prefix-len, "
+        " random-prefix-len + random-prefix-len * random-range-ratio).")
+
+    hf_group = parser.add_argument_group("hf dataset options")
+    hf_group.add_argument("--hf-subset",
+                          type=str,
+                          default=None,
+                          help="Subset of the HF dataset.")
+    hf_group.add_argument("--hf-split",
+                          type=str,
+                          default=None,
+                          help="Split of the HF dataset.")
+    hf_group.add_argument(
+        "--hf-output-len",
+        type=int,
+        default=None,
+        help="Output length for each request. Overrides the output lengths "
+        "from the sampled HF dataset.",
+    )
+
     args = parser.parse_args()
     main(args)
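Grouping the dataset-specific flags makes them render under their own headings in --help output. A standalone sketch of the pattern (runnable on its own; flag names mirror the diff):

    import argparse

    parser = argparse.ArgumentParser(description="demo")
    hf_group = parser.add_argument_group("hf dataset options")
    hf_group.add_argument("--hf-subset", type=str, default=None,
                          help="Subset of the HF dataset.")
    hf_group.add_argument("--hf-split", type=str, default=None,
                          help="Split of the HF dataset.")
    args = parser.parse_args(["--hf-split", "train"])
    # `--help` now lists these flags under an "hf dataset options" heading.

A hypothetical end-to-end invocation (model and dataset ID are placeholders) might look like: python benchmark_serving.py --backend openai-chat --model <model> --dataset-name hf --dataset-path <hf-dataset-id> --hf-split train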