[Feature] Update benchmark_throughput.py to support image input (#9851)

Signed-off-by: Linkun Chen <github+anyscale@lkchen.net>
Co-authored-by: Linkun Chen <github+anyscale@lkchen.net>
lkchen 2024-11-05 11:30:02 -08:00 committed by GitHub
parent a53046b16f
commit d2e80332a7
2 changed files with 75 additions and 18 deletions

benchmarks/README.md

@@ -6,3 +6,14 @@ You can download the dataset by running:
```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```
## Downloading the ShareGPT4V dataset
The JSON file references several image datasets (coco, llava, etc.). The benchmark scripts
will ignore a datapoint if the referenced image is missing.
```bash
wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/resolve/main/sharegpt4v_instruct_gpt4-vision_cap100k.json
mkdir -p coco
wget http://images.cocodataset.org/zips/train2017.zip -O coco/train2017.zip
unzip coco/train2017.zip -d coco/
```
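As a quick sanity check before benchmarking, the snippet below counts how many ShareGPT4V datapoints reference an image that actually exists locally; it mirrors the skip-on-missing-image behaviour described above. It assumes the JSON file and the `coco/` directory were downloaded into the current working directory exactly as in the commands above, and (like the benchmark) that each entry carries a single image path.

```python
import json
from pathlib import Path

# Count datapoints whose referenced image is present locally; the benchmark
# silently skips the rest (missing image -> datapoint ignored).
with open("sharegpt4v_instruct_gpt4-vision_cap100k.json") as f:
    dataset = json.load(f)

usable = sum(1 for d in dataset
             if isinstance(d.get("image"), str) and Path(d["image"]).is_file())
print(f"{usable} of {len(dataset)} datapoints have their image available")
```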

benchmarks/benchmark_throughput.py

@@ -8,6 +8,7 @@ from typing import List, Optional

import torch
import uvloop
from PIL import Image
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)
@@ -38,12 +39,33 @@ class SampleRequest:
    multi_modal_data: Optional[MultiModalDataDict] = None


def _get_prompt_for_image_model(question: str, *, model: str) -> str:
    """Prepend and append special tokens around the question to form a prompt.

    Args:
        question: The input question text to wrap with special tokens
        model: The name of the model being used, to determine which special
            tokens to add

    Returns:
        The formatted prompt string with appropriate special tokens for the
        model

    Raises:
        ValueError: If an unsupported model name is provided
    """
    model = model.lower()
    if "pixtral" in model:
        return f"<s>[INST]{question}\n[IMG][/INST]"
    raise ValueError(f"Unsupported model {model}")


def sample_requests(tokenizer: PreTrainedTokenizerBase,
                    args: argparse.Namespace) -> List[SampleRequest]:
    dataset_path: str = args.dataset
    num_requests: int = args.num_prompts
    fixed_output_len: Optional[int] = args.output_len
    model: str = args.model

    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")
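For illustration, this is the kind of prompt the new helper produces. The model name below is only an example whose lowercased form contains the substring the helper checks for; it is not taken from this diff.

```python
# Uses _get_prompt_for_image_model as defined above; the model name is an
# illustrative placeholder.
prompt = _get_prompt_for_image_model(
    question="What is shown in this photo?",
    model="mistralai/Pixtral-12B-2409")
print(prompt)
# <s>[INST]What is shown in this photo?
# [IMG][/INST]
```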
@@ -52,23 +74,36 @@ def sample_requests(
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short
    filtered_dataset: List[SampleRequest] = []
    for data in dataset:
        if len(filtered_dataset) == num_requests:
            break

        # Only keep the first two turns of each conversation.
        prompt = data["conversations"][0]["value"]
        completion = data["conversations"][1]["value"]

        multi_modal_data: Optional[MultiModalDataDict] = None
        if "image" in data:
            multi_modal_data = multi_modal_data or {}
            image_path = data["image"]
            # TODO(vllm-project/vllm/issues/9778): Support multiple images.
            assert isinstance(image_path,
                              str), "Only support single image input"
            try:
                multi_modal_data["image"] = Image.open(image_path).convert(
                    "RGB")
            except FileNotFoundError:
                # Ignore datapoint where asset is missing
                continue
            prompt = _get_prompt_for_image_model(question=prompt, model=model)

        # Tokenize the prompts and completions.
        prompt_token_ids = tokenizer(prompt).input_ids
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = len(completion_token_ids
@@ -82,7 +117,8 @@ def sample_requests(
        filtered_dataset.append(
            SampleRequest(prompt=prompt,
                          prompt_len=prompt_len,
                          expected_output_len=output_len,
                          multi_modal_data=multi_modal_data))

    return filtered_dataset
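To make the result shape concrete, here is a hand-built example of the SampleRequest one loop iteration appends for an image datapoint. The prompt, token counts, and image path are illustrative; in the loop above they come from the tokenizer, the completion (or the fixed output-length override), and the dataset entry.

```python
from PIL import Image

# Illustrative values only, showing the fields the loop above fills in.
image = Image.open("coco/train2017/000000000009.jpg").convert("RGB")
request = SampleRequest(
    prompt="<s>[INST]Describe the image.\n[IMG][/INST]",
    prompt_len=16,
    expected_output_len=128,
    multi_modal_data={"image": image},
)
```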
@@ -99,7 +135,9 @@ def run_vllm(
    prompts: List[TextPrompt] = []
    sampling_params: List[SamplingParams] = []
    for request in requests:
        prompts.append(
            TextPrompt(prompt=request.prompt,
                       multi_modal_data=request.multi_modal_data))
        sampling_params.append(
            SamplingParams(
                n=n,
@@ -148,7 +186,9 @@ async def run_vllm_async(
    prompts: List[TextPrompt] = []
    sampling_params: List[SamplingParams] = []
    for request in requests:
        prompts.append(
            TextPrompt(prompt=request.prompt,
                       multi_modal_data=request.multi_modal_data))
        sampling_params.append(
            SamplingParams(
                n=n,
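Both run_vllm and run_vllm_async now forward each request's multi_modal_data on the TextPrompt. Below is a rough standalone sketch of how such a prompt reaches the engine; the model name, tokenizer mode, image path, and sampling settings are placeholders rather than part of this commit, and the TextPrompt import path is an assumption based on how this script uses it.

```python
from PIL import Image
from vllm import LLM, SamplingParams
from vllm.inputs import TextPrompt  # import path assumed

# Placeholder model/image; any vLLM-supported multi-modal checkpoint works,
# with whatever setup flags that model requires.
llm = LLM(model="mistralai/Pixtral-12B-2409", tokenizer_mode="mistral")
image = Image.open("coco/train2017/000000000009.jpg").convert("RGB")

prompts = [
    TextPrompt(prompt="<s>[INST]Describe the image.\n[IMG][/INST]",
               multi_modal_data={"image": image}),
]
sampling_params = [SamplingParams(temperature=0.0, max_tokens=64)]

outputs = llm.generate(prompts, sampling_params)
print(outputs[0].outputs[0].text)
```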
@@ -272,9 +312,10 @@ def main(args: argparse.Namespace):
            for _ in range(args.num_prompts)
        ]
    else:
        requests = sample_requests(tokenizer, args)

    is_multi_modal = any(request.multi_modal_data is not None
                         for request in requests)
    if args.backend == "vllm":
        if args.async_engine:
            elapsed_time = uvloop.run(
@@ -300,6 +341,11 @@ def main(args: argparse.Namespace):
                           for request in requests)
    total_output_tokens = sum(request.expected_output_len
                              for request in requests)
    if is_multi_modal:
        print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
              "following metrics are not accurate because image tokens are not"
              " counted. See vllm-project/vllm/issues/9778 for details.")
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
          f"{total_output_tokens / elapsed_time:.2f} output tokens/s")