[Benchmark] Update Vision Arena Dataset and HuggingFaceDataset Setup (#15748)
Signed-off-by: Jennifer Zhao <ai.jenniferzhao@gmail.com>
This commit is contained in:
parent
18ed3132d2
commit
effc5d24fa
@ -41,29 +41,33 @@ become available.
|
||||
<td><code>synthetic</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>HuggingFace</strong></td>
|
||||
<td style="text-align: center;">🟡</td>
|
||||
<td style="text-align: center;">🟡</td>
|
||||
<td>Specify your dataset path on HuggingFace</td>
|
||||
<td><strong>HuggingFace-VisionArena</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td><code>lmarena-ai/VisionArena-Chat</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>VisionArena</strong></td>
|
||||
<td><strong>HuggingFace-InstructCoder</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td><code>lmarena-ai/vision-arena-bench-v0.1</code> (a HuggingFace dataset)</td>
|
||||
<td><code>likaixin/InstructCoder</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>HuggingFace-Other</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td><code>lmms-lab/LLaVA-OneVision-Data</code>, <code>Aeala/ShareGPT_Vicuna_unfiltered</code></td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
✅: supported
|
||||
|
||||
🟡: Partial support
|
||||
|
||||
🚧: to be supported
|
||||
|
||||
🟡: Partial support. Currently, HuggingFaceDataset only supports dataset formats
|
||||
similar to `lmms-lab/LLaVA-OneVision-Data` and `Aeala/ShareGPT_Vicuna_unfiltered`.
|
||||
If you need support for other dataset formats, please consider contributing.
|
||||
|
||||
**Note**: VisionArena’s `dataset-name` should be set to `hf`
|
||||
**Note**: HuggingFace dataset's `dataset-name` should be set to `hf`
|
||||
|
||||
---
|
||||
## Example - Online Benchmark
|
||||
@ -71,8 +75,7 @@ If you need support for other dataset formats, please consider contributing.
|
||||
First start serving your model
|
||||
|
||||
```bash
|
||||
MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
vllm serve ${MODEL_NAME} --disable-log-requests
|
||||
vllm serve NousResearch/Hermes-3-Llama-3.1-8B --disable-log-requests
|
||||
```
|
||||
|
||||
Then run the benchmarking script
|
||||
@ -80,12 +83,13 @@ Then run the benchmarking script
|
||||
```bash
|
||||
# download dataset
|
||||
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
NUM_PROMPTS=10
|
||||
BACKEND="vllm"
|
||||
DATASET_NAME="sharegpt"
|
||||
DATASET_PATH="<your data path>/ShareGPT_V3_unfiltered_cleaned_split.json"
|
||||
python3 vllm/benchmarks/benchmark_serving.py --backend ${BACKEND} --model ${MODEL_NAME} --endpoint /v1/completions --dataset-name ${DATASET_NAME} --dataset-path ${DATASET_PATH} --num-prompts ${NUM_PROMPTS}
|
||||
python3 vllm/benchmarks/benchmark_serving.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--endpoint /v1/completions \
|
||||
--dataset-name sharegpt \
|
||||
--dataset-path <your data path>/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
If successful, you will see the following output
|
||||
@ -122,88 +126,76 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests
|
||||
```
|
||||
|
||||
```bash
|
||||
MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
|
||||
NUM_PROMPTS=10
|
||||
BACKEND="openai-chat"
|
||||
DATASET_NAME="hf"
|
||||
DATASET_PATH="lmarena-ai/vision-arena-bench-v0.1"
|
||||
DATASET_SPLIT='train'
|
||||
|
||||
python3 vllm/benchmarks/benchmark_serving.py \
|
||||
--backend "${BACKEND}" \
|
||||
--model "${MODEL_NAME}" \
|
||||
--endpoint "/v1/chat/completions" \
|
||||
--dataset-name "${DATASET_NAME}" \
|
||||
--dataset-path "${DATASET_PATH}" \
|
||||
--hf-split "${DATASET_SPLIT}" \
|
||||
--num-prompts "${NUM_PROMPTS}"
|
||||
--backend openai-chat \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name hf \
|
||||
--dataset-path lmarena-ai/VisionArena-Chat \
|
||||
--hf-split train \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
### HuggingFaceDataset Examples
|
||||
### InstructCoder Benchmark with Speculative Decoding
|
||||
|
||||
Currently, HuggingFaceDataset only supports dataset formats
|
||||
similar to `lmms-lab/LLaVA-OneVision-Data` and `Aeala/ShareGPT_Vicuna_unfiltered`. If you need support for other dataset
|
||||
formats, please consider contributing.
|
||||
``` bash
|
||||
VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
--speculative-model "[ngram]" \
|
||||
--ngram_prompt_lookup_min 2 \
|
||||
--ngram-prompt-lookup-max 5 \
|
||||
--num_speculative_tokens 5
|
||||
```
|
||||
|
||||
``` bash
|
||||
python3 benchmarks/benchmark_serving.py \
|
||||
--model meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
--dataset-name hf \
|
||||
--dataset-path likaixin/InstructCoder \
|
||||
--num-prompts 2048
|
||||
```
|
||||
|
||||
### Other HuggingFaceDataset Examples
|
||||
|
||||
```bash
|
||||
# need a model with vision capability here
|
||||
vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests
|
||||
```
|
||||
|
||||
**`lmms-lab/LLaVA-OneVision-Data`**
|
||||
|
||||
```bash
|
||||
MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
|
||||
NUM_PROMPTS=10
|
||||
BACKEND="openai-chat"
|
||||
DATASET_NAME="hf"
|
||||
DATASET_PATH="lmms-lab/LLaVA-OneVision-Data"
|
||||
DATASET_SPLIT='train'
|
||||
DATASET_SUBSET='chart2text(cauldron)'
|
||||
python3 vllm/benchmarks/benchmark_serving.py \
|
||||
--backend "${BACKEND}" \
|
||||
--model "${MODEL_NAME}" \
|
||||
--endpoint "/v1/chat/completions" \
|
||||
--dataset-name "${DATASET_NAME}" \
|
||||
--dataset-path "${DATASET_PATH}" \
|
||||
--hf-split "${DATASET_SPLIT}" \
|
||||
--num-prompts "${NUM_PROMPTS}" \
|
||||
--hf-subset "${DATASET_SUBSET}"
|
||||
--backend openai-chat \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name hf \
|
||||
--dataset-path lmms-lab/LLaVA-OneVision-Data \
|
||||
--hf-split train \
|
||||
--hf-subset "chart2text(cauldron)" \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
**`Aeala/ShareGPT_Vicuna_unfiltered`**
|
||||
|
||||
```bash
|
||||
MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
|
||||
NUM_PROMPTS=10
|
||||
BACKEND="openai-chat"
|
||||
DATASET_NAME="hf"
|
||||
DATASET_PATH="Aeala/ShareGPT_Vicuna_unfiltered"
|
||||
DATASET_SPLIT='train'
|
||||
python3 vllm/benchmarks/benchmark_serving.py \
|
||||
--backend "${BACKEND}" \
|
||||
--model "${MODEL_NAME}" \
|
||||
--endpoint "/v1/chat/completions" \
|
||||
--dataset-name "${DATASET_NAME}" \
|
||||
--dataset-path "${DATASET_PATH}" \
|
||||
--hf-split "${DATASET_SPLIT}" \
|
||||
--num-prompts "${NUM_PROMPTS}" \
|
||||
--backend openai-chat \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name hf \
|
||||
--dataset-path Aeala/ShareGPT_Vicuna_unfiltered \
|
||||
--hf-split train \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
---
|
||||
## Example - Offline Throughput Benchmark
|
||||
|
||||
```bash
|
||||
MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
NUM_PROMPTS=10
|
||||
DATASET_NAME="sonnet"
|
||||
DATASET_PATH="vllm/benchmarks/sonnet.txt"
|
||||
|
||||
python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
--model "${MODEL_NAME}" \
|
||||
--dataset-name "${DATASET_NAME}" \
|
||||
--dataset-path "${DATASET_PATH}" \
|
||||
--num-prompts "${NUM_PROMPTS}"
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset-name sonnet \
|
||||
--dataset-path vllm/benchmarks/sonnet.txt \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
If successful, you will see the following output
|
||||
@ -217,19 +209,13 @@ Total num output tokens: 1500
|
||||
### VisionArena Benchmark for Vision Language Models
|
||||
|
||||
``` bash
|
||||
MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
|
||||
NUM_PROMPTS=10
|
||||
DATASET_NAME="hf"
|
||||
DATASET_PATH="lmarena-ai/vision-arena-bench-v0.1"
|
||||
DATASET_SPLIT="train"
|
||||
|
||||
python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
--model "${MODEL_NAME}" \
|
||||
--backend "vllm-chat" \
|
||||
--dataset-name "${DATASET_NAME}" \
|
||||
--dataset-path "${DATASET_PATH}" \
|
||||
--num-prompts "${NUM_PROMPTS}" \
|
||||
--hf-split "${DATASET_SPLIT}"
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--backend vllm-chat \
|
||||
--dataset-name hf \
|
||||
--dataset-path lmarena-ai/VisionArena-Chat \
|
||||
--num-prompts 1000 \
|
||||
--hf-split train
|
||||
```
|
||||
|
||||
The `num prompt tokens` now includes image token counts
|
||||
@ -240,29 +226,71 @@ Total num prompt tokens: 14527
|
||||
Total num output tokens: 1280
|
||||
```
|
||||
|
||||
### InstructCoder Benchmark with Speculative Decoding
|
||||
|
||||
``` bash
|
||||
VLLM_WORKER_MULTIPROC_METHOD=spawn \
|
||||
VLLM_USE_V1=1 \
|
||||
python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
--dataset-name=hf \
|
||||
--dataset-path=likaixin/InstructCoder \
|
||||
--model=meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
--input-len=1000 \
|
||||
--output-len=100 \
|
||||
--num-prompts=2048 \
|
||||
--async-engine \
|
||||
--speculative-model="[ngram]" \
|
||||
--ngram_prompt_lookup_min=2 \
|
||||
--ngram-prompt-lookup-max=5 \
|
||||
--num_speculative_tokens=5
|
||||
```
|
||||
|
||||
```
|
||||
Throughput: 104.77 requests/s, 23836.22 total tokens/s, 10477.10 output tokens/s
|
||||
Total num prompt tokens: 261136
|
||||
Total num output tokens: 204800
|
||||
```
|
||||
|
||||
### Other HuggingFaceDataset Examples
|
||||
|
||||
**`lmms-lab/LLaVA-OneVision-Data`**
|
||||
|
||||
```bash
|
||||
python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--backend vllm-chat \
|
||||
--dataset-name hf \
|
||||
--dataset-path lmms-lab/LLaVA-OneVision-Data \
|
||||
--hf-split train \
|
||||
--hf-subset "chart2text(cauldron)" \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
**`Aeala/ShareGPT_Vicuna_unfiltered`**
|
||||
|
||||
```bash
|
||||
python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--backend vllm-chat \
|
||||
--dataset-name hf \
|
||||
--dataset-path Aeala/ShareGPT_Vicuna_unfiltered \
|
||||
--hf-split train \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
### Benchmark with LoRA Adapters
|
||||
|
||||
``` bash
|
||||
# download dataset
|
||||
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
MODEL_NAME="meta-llama/Llama-2-7b-hf"
|
||||
BACKEND="vllm"
|
||||
DATASET_NAME="sharegpt"
|
||||
DATASET_PATH="<your data path>/ShareGPT_V3_unfiltered_cleaned_split.json"
|
||||
NUM_PROMPTS=10
|
||||
MAX_LORAS=2
|
||||
MAX_LORA_RANK=8
|
||||
ENABLE_LORA="--enable-lora"
|
||||
LORA_PATH="yard1/llama-2-7b-sql-lora-test"
|
||||
|
||||
python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
--model "${MODEL_NAME}" \
|
||||
--backend "${BACKEND}" \
|
||||
--dataset_path "${DATASET_PATH}" \
|
||||
--dataset_name "${DATASET_NAME}" \
|
||||
--num-prompts "${NUM_PROMPTS}" \
|
||||
--max-loras "${MAX_LORAS}" \
|
||||
--max-lora-rank "${MAX_LORA_RANK}" \
|
||||
${ENABLE_LORA} \
|
||||
--lora-path "${LORA_PATH}"
|
||||
--model meta-llama/Llama-2-7b-hf \
|
||||
--backend vllm \
|
||||
--dataset_path <your data path>/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--dataset_name sharegpt \
|
||||
--num-prompts 10 \
|
||||
--max-loras 2 \
|
||||
--max-lora-rank 8 \
|
||||
--enable-lora \
|
||||
--lora-path yard1/llama-2-7b-sql-lora-test
|
||||
```
|
||||
|
@ -23,7 +23,8 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from typing import Any, Optional, Union
|
||||
from io import BytesIO
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -239,21 +240,24 @@ def process_image(image: Any) -> Mapping[str, Any]:
|
||||
"""
|
||||
Process a single image input and return a multimedia content dictionary.
|
||||
|
||||
For a PIL.Image.Image input:
|
||||
- Converts the image to RGB.
|
||||
- Saves the image as a JPEG in-memory.
|
||||
- Encodes the JPEG data as a base64 string.
|
||||
- Returns a dictionary with the image as a base64 data URL.
|
||||
Supports three input types:
|
||||
|
||||
For a string input:
|
||||
- Treats the string as a URL or file path.
|
||||
- Prepends "file://" if the string doesn't start with "http://" or
|
||||
"file://".
|
||||
- Returns a dictionary with the image URL.
|
||||
1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key
|
||||
containing raw image data. - Loads the bytes as a PIL.Image.Image.
|
||||
|
||||
2. PIL.Image.Image input: - Converts the image to RGB. - Saves the image as
|
||||
a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns
|
||||
a dictionary with the image as a base64 data URL.
|
||||
|
||||
3. String input: - Treats the string as a URL or local file path. -
|
||||
Prepends "file://" if the string doesn't start with "http://" or
|
||||
"file://". - Returns a dictionary with the image URL.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input is neither a PIL.Image.Image nor a string.
|
||||
ValueError: If the input is not a supported type.
|
||||
"""
|
||||
if isinstance(image, dict) and 'bytes' in image:
|
||||
image = Image.open(BytesIO(image['bytes']))
|
||||
if isinstance(image, Image.Image):
|
||||
image = image.convert("RGB")
|
||||
with io.BytesIO() as image_data:
|
||||
@ -272,8 +276,8 @@ def process_image(image: Any) -> Mapping[str, Any]:
|
||||
("http://", "file://")) else f"file://{image}")
|
||||
return {"type": "image_url", "image_url": {"url": image_url}}
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid image input {image}. Must be a PIL.Image.Image or str.")
|
||||
raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image"
|
||||
" or str or dictionary with raw image bytes.")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@ -562,48 +566,56 @@ class BurstGPTDataset(BenchmarkDataset):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# HuggingFace Dataset Implementation
|
||||
# HuggingFace Dataset Base Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class HuggingFaceDataset(BenchmarkDataset):
|
||||
"""
|
||||
Dataset class for processing a HuggingFace dataset with conversation data
|
||||
and optional images.
|
||||
"""
|
||||
"""Base class for datasets hosted on HuggingFace."""
|
||||
|
||||
SUPPORTED_DATASET_PATHS: Union[set[str], dict[str, Callable]] = set()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_path: str,
|
||||
dataset_split: str,
|
||||
dataset_subset: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(dataset_path=dataset_path, **kwargs)
|
||||
|
||||
# Validate dataset path
|
||||
if self.SUPPORTED_DATASET_PATHS and \
|
||||
self.dataset_path not in self.SUPPORTED_DATASET_PATHS:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} "
|
||||
f"only supports: {', '.join(self.SUPPORTED_DATASET_PATHS)}. "
|
||||
"Please consider contributing if you would "
|
||||
"like to add support for additional dataset formats.")
|
||||
|
||||
self.dataset_split = dataset_split
|
||||
self.dataset_subset = dataset_subset
|
||||
|
||||
self.load_data()
|
||||
|
||||
def load_data(self) -> None:
|
||||
if not self.dataset_path:
|
||||
raise ValueError("dataset_path must be provided for loading data.")
|
||||
|
||||
"""Load data from HuggingFace datasets."""
|
||||
self.data = load_dataset(
|
||||
self.dataset_path,
|
||||
name=self.dataset_subset,
|
||||
split=self.dataset_split,
|
||||
streaming=True,
|
||||
)
|
||||
if self.data.features is None or "conversations" \
|
||||
not in self.data.features:
|
||||
raise ValueError(
|
||||
"HuggingFaceDataset currently only supports datasets with "
|
||||
"a 'conversations' column like lmms-lab/LLaVA-OneVision-Data. "
|
||||
"Please consider contributing if you would like to add "
|
||||
"support for additional dataset formats.")
|
||||
# Shuffle and filter examples with at least 2 conversations.
|
||||
self.data = self.data.shuffle(seed=self.random_seed).filter(
|
||||
lambda x: len(x["conversations"]) >= 2)
|
||||
self.data = self.data.shuffle(seed=self.random_seed)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Conversation Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ConversationDataset(HuggingFaceDataset):
|
||||
"""Dataset for conversation data with multimodal support."""
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
|
||||
}
|
||||
|
||||
def sample(self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
@ -611,10 +623,13 @@ class HuggingFaceDataset(BenchmarkDataset):
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
**kwargs) -> list:
|
||||
# Filter examples with at least 2 conversations
|
||||
filtered_data = self.data.filter(
|
||||
lambda x: len(x["conversations"]) >= 2)
|
||||
sampled_requests = []
|
||||
dynamic_output = output_len is None
|
||||
|
||||
for item in self.data:
|
||||
for item in filtered_data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
conv = item["conversations"]
|
||||
@ -659,29 +674,12 @@ class VisionArenaDataset(HuggingFaceDataset):
|
||||
"""
|
||||
|
||||
DEFAULT_OUTPUT_LEN = 128
|
||||
VISION_ARENA_DATASET_PATH = "lmarena-ai/vision-arena-bench-v0.1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
if self.dataset_path != self.VISION_ARENA_DATASET_PATH:
|
||||
raise ValueError(f"Only support Vision Arena dataset.\
|
||||
This data path {self.dataset_path} is not valid.")
|
||||
if self.dataset_subset is None and self.dataset_split != "train":
|
||||
raise ValueError("Dataset split must be 'train'.")
|
||||
|
||||
self.load_data()
|
||||
|
||||
def load_data(self) -> None:
|
||||
dataset = load_dataset(
|
||||
self.dataset_path,
|
||||
name=self.dataset_subset,
|
||||
split=self.dataset_split,
|
||||
streaming=True,
|
||||
)
|
||||
self.data = dataset.shuffle(seed=self.random_seed)
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
"lmarena-ai/VisionArena-Chat":
|
||||
lambda x: x["conversation"][0][0]["content"],
|
||||
"lmarena-ai/vision-arena-bench-v0.1":
|
||||
lambda x: x["turns"][0][0]["content"]
|
||||
}
|
||||
|
||||
def sample(
|
||||
self,
|
||||
@ -697,7 +695,11 @@ class VisionArenaDataset(HuggingFaceDataset):
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
prompt = item["turns"][0][0]["content"]
|
||||
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
|
||||
if parser_fn is None:
|
||||
raise ValueError(
|
||||
f"Unsupported dataset path: {self.dataset_path}")
|
||||
prompt = parser_fn(item)
|
||||
mm_content = process_image(item["images"][0])
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
if enable_multimodal_chat:
|
||||
@ -727,34 +729,15 @@ class InstructCoderDataset(HuggingFaceDataset):
|
||||
InstructCoder Dataset.
|
||||
https://huggingface.co/datasets/likaixin/InstructCoder
|
||||
|
||||
InstructCoder is the dataset designed for general code editing.
|
||||
It consists of 114,239 instruction-input-output triplets,
|
||||
and covers multiple distinct code editing scenario.
|
||||
InstructCoder is the dataset designed for general code editing. It consists
|
||||
of 114,239 instruction-input-output triplets, and covers multiple distinct
|
||||
code editing scenario.
|
||||
"""
|
||||
|
||||
DEFAULT_OUTPUT_LEN = 200 # this is the average default output length
|
||||
DEFAULT_NUM_REQUESTS = 1000
|
||||
INSTRUCT_CODER_DATASET_PATH = "likaixin/InstructCoder"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
if self.dataset_path != self.INSTRUCT_CODER_DATASET_PATH:
|
||||
raise ValueError(f"Only support likaixin/InstructCoder dataset.\
|
||||
This data path {self.dataset_path} is not valid.")
|
||||
if self.dataset_subset is None and self.dataset_split != "train":
|
||||
raise ValueError("Dataset split must be 'train'.")
|
||||
|
||||
def load_data(self) -> None:
|
||||
dataset = load_dataset(
|
||||
self.dataset_path,
|
||||
name=self.dataset_subset,
|
||||
split=self.dataset_split,
|
||||
streaming=True,
|
||||
)
|
||||
self.data = dataset.shuffle(seed=self.random_seed)
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
"likaixin/InstructCoder",
|
||||
}
|
||||
|
||||
def sample(self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
|
@ -49,7 +49,7 @@ try:
|
||||
except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
|
||||
from benchmark_dataset import (BurstGPTDataset, ConversationDataset,
|
||||
InstructCoderDataset, RandomDataset,
|
||||
SampleRequest, ShareGPTDataset, SonnetDataset,
|
||||
VisionArenaDataset)
|
||||
@ -584,16 +584,17 @@ def main(args: argparse.Namespace):
|
||||
return_prompt_formatted=True)
|
||||
|
||||
elif args.dataset_name == "hf":
|
||||
# Choose between VisionArenaDataset
|
||||
# and HuggingFaceDataset based on provided parameters.
|
||||
dataset_class = HuggingFaceDataset
|
||||
if args.dataset_path == VisionArenaDataset.VISION_ARENA_DATASET_PATH:
|
||||
assert args.hf_subset is None, "VisionArenaDataset needs hf_subset to be None." #noqa: E501
|
||||
# all following datasets are implemented from the
|
||||
# HuggingFaceDataset base class
|
||||
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = VisionArenaDataset
|
||||
elif args.dataset_path == "likaixin/InstructCoder":
|
||||
args.hf_split = "train"
|
||||
args.hf_subset = None
|
||||
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = InstructCoderDataset
|
||||
args.hf_split = "train"
|
||||
|
||||
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = ConversationDataset
|
||||
input_requests = dataset_class(
|
||||
dataset_path=args.dataset_path,
|
||||
dataset_subset=args.hf_subset,
|
||||
|
@ -11,7 +11,7 @@ from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
|
||||
from benchmark_dataset import (BurstGPTDataset, ConversationDataset,
|
||||
InstructCoderDataset, RandomDataset,
|
||||
SampleRequest, ShareGPTDataset, SonnetDataset,
|
||||
VisionArenaDataset)
|
||||
@ -319,21 +319,19 @@ def get_requests(args, tokenizer):
|
||||
elif args.dataset_name == "burstgpt":
|
||||
dataset_cls = BurstGPTDataset
|
||||
elif args.dataset_name == "hf":
|
||||
if args.dataset_path == VisionArenaDataset.VISION_ARENA_DATASET_PATH:
|
||||
if args.args.backend == "vllm-chat":
|
||||
raise ValueError(
|
||||
"hf datasets only are supported by vllm-chat backend")
|
||||
# Choose between VisionArenaDataset and HuggingFaceDataset based on
|
||||
# provided parameters.
|
||||
dataset_cls = (VisionArenaDataset if args.dataset_path
|
||||
== VisionArenaDataset.VISION_ARENA_DATASET_PATH
|
||||
and args.hf_subset is None else HuggingFaceDataset)
|
||||
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = VisionArenaDataset
|
||||
common_kwargs['dataset_subset'] = None
|
||||
common_kwargs['dataset_split'] = "train"
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = InstructCoderDataset
|
||||
common_kwargs['dataset_split'] = "train"
|
||||
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = ConversationDataset
|
||||
common_kwargs['dataset_subset'] = args.hf_subset
|
||||
common_kwargs['dataset_split'] = args.hf_split
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_path == "likaixin/InstructCoder":
|
||||
dataset_cls = InstructCoderDataset
|
||||
common_kwargs['dataset_split'] = "train"
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
|
||||
@ -469,10 +467,12 @@ def validate_args(args):
|
||||
since --dataset-name is not 'hf'.",
|
||||
stacklevel=2)
|
||||
elif args.dataset_name == "hf":
|
||||
if args.dataset_path == VisionArenaDataset.VISION_ARENA_DATASET_PATH:
|
||||
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
|
||||
assert args.backend == "vllm-chat", "VisionArenaDataset needs to use vllm-chat as the backend." #noqa: E501
|
||||
elif args.dataset_path == "likaixin/InstructCoder":
|
||||
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
|
||||
assert args.backend == "vllm", "InstructCoder dataset needs to use vllm as the backend." #noqa: E501
|
||||
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
||||
assert args.backend == "vllm-chat", "ConversationDataset needs to use vllm-chat as the backend." #noqa: E501
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{args.dataset_path} is not supported by hf dataset.")
|
||||
|
Loading…
x
Reference in New Issue
Block a user