[Feature] Update benchmark_throughput.py to support image input (#9851)
Signed-off-by: Linkun Chen <github+anyscale@lkchen.net>
Co-authored-by: Linkun Chen <github+anyscale@lkchen.net>
parent a53046b16f
commit d2e80332a7
benchmarks/README.md
@@ -6,3 +6,14 @@ You can download the dataset by running:
 ```bash
 wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
 ```
+
+## Downloading the ShareGPT4V dataset
+
+The json file refers to several image datasets (coco, llava, etc.). The benchmark scripts
+will ignore a datapoint if the referred image is missing.
+```bash
+wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/resolve/main/sharegpt4v_instruct_gpt4-vision_cap100k.json
+mkdir coco -p
+wget http://images.cocodataset.org/zips/train2017.zip -O coco/train2017.zip
+unzip coco/train2017.zip -d coco/
+```
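A quick way to check how much of the downloaded ShareGPT4V split is actually usable after fetching only the coco images; this is a minimal sketch, assuming the commands above were run in the current directory and that the json's "image" paths resolve relative to it:

```python
import json
import os

# Mirrors the benchmark's behaviour: a datapoint whose referenced image is
# missing on disk is simply skipped.
with open("sharegpt4v_instruct_gpt4-vision_cap100k.json") as f:
    dataset = json.load(f)

usable = sum(1 for data in dataset
             if "image" not in data or os.path.exists(data["image"]))
print(f"{usable}/{len(dataset)} datapoints have their image available")
```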
benchmarks/benchmark_throughput.py
@@ -8,6 +8,7 @@ from typing import List, Optional
 
 import torch
 import uvloop
+from PIL import Image
 from tqdm import tqdm
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           PreTrainedTokenizerBase)
@@ -38,12 +39,33 @@ class SampleRequest:
     multi_modal_data: Optional[MultiModalDataDict] = None
 
 
-def sample_requests(
-    dataset_path: str,
-    num_requests: int,
-    tokenizer: PreTrainedTokenizerBase,
-    fixed_output_len: Optional[int],
-) -> List[SampleRequest]:
+def _get_prompt_for_image_model(question: str, *, model: str) -> str:
+    """Prepend and append special tokens around the question to form a prompt.
+
+    Args:
+        question: The input question text to wrap with special tokens
+        model: The name of the model being used, to determine which special
+            tokens to add
+
+    Returns:
+        The formatted prompt string with appropriate special tokens for the
+            model
+
+    Raises:
+        ValueError: If an unsupported model name is provided
+    """
+    model = model.lower()
+    if "pixtral" in model:
+        return f"<s>[INST]{question}\n[IMG][/INST]"
+    raise ValueError(f"Unsupported model {model}")
+
+
+def sample_requests(tokenizer: PreTrainedTokenizerBase,
+                    args: argparse.Namespace) -> List[SampleRequest]:
+    dataset_path: str = args.dataset
+    num_requests: int = args.num_prompts
+    fixed_output_len: Optional[int] = args.output_len
+    model: str = args.model
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
 
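For reference, a minimal sketch of what the new helper produces. The Pixtral template is the only one wired up in this change; the model id used below is just an example, not something taken from the diff.

```python
# Illustrative only: the helper lower-cases the model name, matches on the
# substring "pixtral", and wraps the question in Mistral-style instruction
# and image tokens.
prompt = _get_prompt_for_image_model(
    question="What is shown in this image?",
    model="mistralai/Pixtral-12B-2409")  # example model id, not from the diff
assert prompt == "<s>[INST]What is shown in this image?\n[IMG][/INST]"

# Any other model name raises, e.g.
# _get_prompt_for_image_model(question="hi", model="llava")  -> ValueError
```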
@@ -52,23 +74,36 @@ def sample_requests(
         dataset = json.load(f)
     # Filter out the conversations with less than 2 turns.
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
-    # Only keep the first two turns of each conversation.
-    dataset = [(data["conversations"][0]["value"],
-                data["conversations"][1]["value"]) for data in dataset]
 
     # Shuffle the dataset.
     random.shuffle(dataset)
 
     # Filter out sequences that are too long or too short
     filtered_dataset: List[SampleRequest] = []
-    for i in range(len(dataset)):
+    for data in dataset:
         if len(filtered_dataset) == num_requests:
             break
 
+        # Only keep the first two turns of each conversation.
+        prompt = data["conversations"][0]["value"]
+        completion = data["conversations"][1]["value"]
+
+        multi_modal_data: Optional[MultiModalDataDict] = None
+        if "image" in data:
+            multi_modal_data = multi_modal_data or {}
+            image_path = data["image"]
+            # TODO(vllm-project/vllm/issues/9778): Support multiple images.
+            assert isinstance(image_path,
+                              str), "Only support single image input"
+            try:
+                multi_modal_data["image"] = Image.open(image_path).convert(
+                    "RGB")
+            except FileNotFoundError:
+                # Ignore datapoint where asset is missing
+                continue
+            prompt = _get_prompt_for_image_model(question=prompt, model=model)
+
         # Tokenize the prompts and completions.
-        prompt = dataset[i][0]
         prompt_token_ids = tokenizer(prompt).input_ids
-        completion = dataset[i][1]
         completion_token_ids = tokenizer(completion).input_ids
         prompt_len = len(prompt_token_ids)
         output_len = len(completion_token_ids
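The loop above assumes ShareGPT-style records that may optionally carry an "image" path. The field names below come from the code above; the values (and the "from" keys) are made up for illustration:

```python
# A datapoint without "image" is treated as a text-only request; one whose
# image file is missing on disk is skipped entirely.
example_datapoint = {
    "image": "coco/train2017/000000033471.jpg",  # hypothetical local path
    "conversations": [
        {"from": "human", "value": "What is happening in this picture?"},
        {"from": "gpt", "value": "A group of people crossing a busy street."},
    ],
}
```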
@@ -82,7 +117,8 @@ def sample_requests(
         filtered_dataset.append(
             SampleRequest(prompt=prompt,
                           prompt_len=prompt_len,
-                          expected_output_len=output_len))
+                          expected_output_len=output_len,
+                          multi_modal_data=multi_modal_data))
 
     return filtered_dataset
 
@@ -99,7 +135,9 @@ def run_vllm(
     prompts: List[TextPrompt] = []
     sampling_params: List[SamplingParams] = []
     for request in requests:
-        prompts.append(TextPrompt(prompt=request.prompt))
+        prompts.append(
+            TextPrompt(prompt=request.prompt,
+                       multi_modal_data=request.multi_modal_data))
         sampling_params.append(
             SamplingParams(
                 n=n,
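Downstream, each TextPrompt now carries the decoded PIL image alongside the text. A minimal, self-contained sketch of the same pattern against vLLM's offline API; the model id and image path are placeholders, not taken from the commit:

```python
from PIL import Image
from vllm import LLM, SamplingParams
from vllm.inputs import TextPrompt

# Placeholder model and image; any vLLM-supported vision-language model works.
llm = LLM(model="mistralai/Pixtral-12B-2409")
image = Image.open("coco/train2017/000000033471.jpg").convert("RGB")

# Same structure run_vllm builds for every request: the text prompt plus
# multi_modal_data holding a single image.
prompt = TextPrompt(prompt="<s>[INST]What is shown in this image?\n[IMG][/INST]",
                    multi_modal_data={"image": image})
outputs = llm.generate([prompt], SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)
```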
@@ -148,7 +186,9 @@ async def run_vllm_async(
     prompts: List[TextPrompt] = []
     sampling_params: List[SamplingParams] = []
     for request in requests:
-        prompts.append(TextPrompt(prompt=request.prompt))
+        prompts.append(
+            TextPrompt(prompt=request.prompt,
+                       multi_modal_data=request.multi_modal_data))
         sampling_params.append(
             SamplingParams(
                 n=n,
@@ -272,9 +312,10 @@ def main(args: argparse.Namespace):
             for _ in range(args.num_prompts)
         ]
     else:
-        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
-                                   args.output_len)
+        requests = sample_requests(tokenizer, args)
 
+    is_multi_modal = any(request.multi_modal_data is not None
+                         for request in requests)
     if args.backend == "vllm":
         if args.async_engine:
             elapsed_time = uvloop.run(
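sample_requests now reads everything it needs from the parsed CLI arguments instead of taking four positional parameters. A hedged sketch of the minimal namespace it expects, based only on the attributes accessed in the diff; the values are placeholders:

```python
import argparse

# Only the attributes the refactored sample_requests reads are shown; the real
# script builds this namespace from its full argument parser.
args = argparse.Namespace(
    dataset="sharegpt4v_instruct_gpt4-vision_cap100k.json",
    num_prompts=100,
    output_len=None,  # None -> use each reference completion's length
    model="mistralai/Pixtral-12B-2409",  # placeholder model id
)
# requests = sample_requests(tokenizer, args)  # tokenizer: any PreTrainedTokenizerBase
```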
@@ -300,6 +341,11 @@ def main(args: argparse.Namespace):
                               for request in requests)
     total_output_tokens = sum(request.expected_output_len
                               for request in requests)
+    if is_multi_modal:
+        print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
+              "following metrics are not accurate because image tokens are not"
+              " counted. See vllm-project/vllm/issues/9778 for details.")
+        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
     print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
           f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
           f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
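For context, a worked example of the three reported figures with made-up numbers; image tokens are excluded from both token totals, which is exactly what the new warning flags:

```python
# 100 requests, 51_200 prompt+output text tokens, 12_800 output tokens, 16 s run.
num_requests, total_num_tokens, total_output_tokens, elapsed_time = 100, 51_200, 12_800, 16.0
print(f"Throughput: {num_requests / elapsed_time:.2f} requests/s, "
      f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
      f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
# -> Throughput: 6.25 requests/s, 3200.00 total tokens/s, 800.00 output tokens/s
```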