diff --git a/benchmarks/README.md b/benchmarks/README.md
index d41de1ca..4777d832 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -41,29 +41,33 @@ become available.
| Random | ✅ | ✅ | synthetic |
-| HuggingFace | 🟡 | 🟡 | Specify your dataset path on HuggingFace |
-| VisionArena | ✅ | ✅ | lmarena-ai/vision-arena-bench-v0.1 (a HuggingFace dataset) |
+| HuggingFace-VisionArena | ✅ | ✅ | lmarena-ai/VisionArena-Chat |
+| HuggingFace-InstructCoder | ✅ | ✅ | likaixin/InstructCoder |
+| HuggingFace-Other | ✅ | ✅ | lmms-lab/LLaVA-OneVision-Data, Aeala/ShareGPT_Vicuna_unfiltered |

✅: supported
+🟡: Partial support
+
🚧: to be supported
-🟡: Partial support. Currently, HuggingFaceDataset only supports dataset formats
-similar to `lmms-lab/LLaVA-OneVision-Data` and `Aeala/ShareGPT_Vicuna_unfiltered`.
-If you need support for other dataset formats, please consider contributing.
-
-**Note**: VisionArena's `dataset-name` should be set to `hf`
+**Note**: HuggingFace dataset's `dataset-name` should be set to `hf`
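+For example, benchmarking on InstructCoder uses
+`--dataset-name hf --dataset-path likaixin/InstructCoder`.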
---
## Example - Online Benchmark
@@ -71,8 +75,7 @@ If you need support for other dataset formats, please consider contributing.
First start serving your model
```bash
-MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B"
-vllm serve ${MODEL_NAME} --disable-log-requests
+vllm serve NousResearch/Hermes-3-Llama-3.1-8B --disable-log-requests
```
Then run the benchmarking script
@@ -80,12 +83,13 @@ Then run the benchmarking script
```bash
# download dataset
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
-MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B"
-NUM_PROMPTS=10
-BACKEND="vllm"
-DATASET_NAME="sharegpt"
-DATASET_PATH="/ShareGPT_V3_unfiltered_cleaned_split.json"
-python3 vllm/benchmarks/benchmark_serving.py --backend ${BACKEND} --model ${MODEL_NAME} --endpoint /v1/completions --dataset-name ${DATASET_NAME} --dataset-path ${DATASET_PATH} --num-prompts ${NUM_PROMPTS}
+python3 vllm/benchmarks/benchmark_serving.py \
+ --backend vllm \
+ --model NousResearch/Hermes-3-Llama-3.1-8B \
+ --endpoint /v1/completions \
+ --dataset-name sharegpt \
+ --dataset-path /ShareGPT_V3_unfiltered_cleaned_split.json \
+ --num-prompts 10
```
If successful, you will see the following output
@@ -122,88 +126,76 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests
```
```bash
-MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
-NUM_PROMPTS=10
-BACKEND="openai-chat"
-DATASET_NAME="hf"
-DATASET_PATH="lmarena-ai/vision-arena-bench-v0.1"
-DATASET_SPLIT='train'
-
python3 vllm/benchmarks/benchmark_serving.py \
- --backend "${BACKEND}" \
- --model "${MODEL_NAME}" \
- --endpoint "/v1/chat/completions" \
- --dataset-name "${DATASET_NAME}" \
- --dataset-path "${DATASET_PATH}" \
- --hf-split "${DATASET_SPLIT}" \
- --num-prompts "${NUM_PROMPTS}"
+ --backend openai-chat \
+ --model Qwen/Qwen2-VL-7B-Instruct \
+ --endpoint /v1/chat/completions \
+ --dataset-name hf \
+ --dataset-path lmarena-ai/VisionArena-Chat \
+ --hf-split train \
+ --num-prompts 1000
```
-### HuggingFaceDataset Examples
+### InstructCoder Benchmark with Speculative Decoding
-Currently, HuggingFaceDataset only supports dataset formats
-similar to `lmms-lab/LLaVA-OneVision-Data` and `Aeala/ShareGPT_Vicuna_unfiltered`. If you need support for other dataset
-formats, please consider contributing.
+``` bash
+VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
+ --speculative-model "[ngram]" \
+    --ngram-prompt-lookup-min 2 \
+    --ngram-prompt-lookup-max 5 \
+    --num-speculative-tokens 5
+```
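+
+Ngram speculation drafts tokens by matching the most recently generated
+tokens against earlier n-grams in the context, so no separate draft model
+is needed. Code-editing workloads such as InstructCoder tend to benefit
+because much of the output repeats the input verbatim.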
+
+``` bash
+python3 benchmarks/benchmark_serving.py \
+ --model meta-llama/Meta-Llama-3-8B-Instruct \
+ --dataset-name hf \
+ --dataset-path likaixin/InstructCoder \
+ --num-prompts 2048
+```
+
+### Other HuggingFaceDataset Examples
```bash
-# need a model with vision capability here
vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests
```
**`lmms-lab/LLaVA-OneVision-Data`**
```bash
-MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
-NUM_PROMPTS=10
-BACKEND="openai-chat"
-DATASET_NAME="hf"
-DATASET_PATH="lmms-lab/LLaVA-OneVision-Data"
-DATASET_SPLIT='train'
-DATASET_SUBSET='chart2text(cauldron)'
python3 vllm/benchmarks/benchmark_serving.py \
- --backend "${BACKEND}" \
- --model "${MODEL_NAME}" \
- --endpoint "/v1/chat/completions" \
- --dataset-name "${DATASET_NAME}" \
- --dataset-path "${DATASET_PATH}" \
- --hf-split "${DATASET_SPLIT}" \
- --num-prompts "${NUM_PROMPTS}" \
- --hf-subset "${DATASET_SUBSET}"
+ --backend openai-chat \
+ --model Qwen/Qwen2-VL-7B-Instruct \
+ --endpoint /v1/chat/completions \
+ --dataset-name hf \
+ --dataset-path lmms-lab/LLaVA-OneVision-Data \
+ --hf-split train \
+ --hf-subset "chart2text(cauldron)" \
+ --num-prompts 10
```
**`Aeala/ShareGPT_Vicuna_unfiltered`**
```bash
-MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
-NUM_PROMPTS=10
-BACKEND="openai-chat"
-DATASET_NAME="hf"
-DATASET_PATH="Aeala/ShareGPT_Vicuna_unfiltered"
-DATASET_SPLIT='train'
python3 vllm/benchmarks/benchmark_serving.py \
- --backend "${BACKEND}" \
- --model "${MODEL_NAME}" \
- --endpoint "/v1/chat/completions" \
- --dataset-name "${DATASET_NAME}" \
- --dataset-path "${DATASET_PATH}" \
- --hf-split "${DATASET_SPLIT}" \
- --num-prompts "${NUM_PROMPTS}" \
+ --backend openai-chat \
+ --model Qwen/Qwen2-VL-7B-Instruct \
+ --endpoint /v1/chat/completions \
+ --dataset-name hf \
+ --dataset-path Aeala/ShareGPT_Vicuna_unfiltered \
+ --hf-split train \
+ --num-prompts 10
```
---
## Example - Offline Throughput Benchmark
```bash
-MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B"
-NUM_PROMPTS=10
-DATASET_NAME="sonnet"
-DATASET_PATH="vllm/benchmarks/sonnet.txt"
-
python3 vllm/benchmarks/benchmark_throughput.py \
- --model "${MODEL_NAME}" \
- --dataset-name "${DATASET_NAME}" \
- --dataset-path "${DATASET_PATH}" \
- --num-prompts "${NUM_PROMPTS}"
+ --model NousResearch/Hermes-3-Llama-3.1-8B \
+ --dataset-name sonnet \
+ --dataset-path vllm/benchmarks/sonnet.txt \
+ --num-prompts 10
```
If successful, you will see the following output
@@ -217,19 +209,13 @@ Total num output tokens: 1500
### VisionArena Benchmark for Vision Language Models
``` bash
-MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
-NUM_PROMPTS=10
-DATASET_NAME="hf"
-DATASET_PATH="lmarena-ai/vision-arena-bench-v0.1"
-DATASET_SPLIT="train"
-
python3 vllm/benchmarks/benchmark_throughput.py \
- --model "${MODEL_NAME}" \
- --backend "vllm-chat" \
- --dataset-name "${DATASET_NAME}" \
- --dataset-path "${DATASET_PATH}" \
- --num-prompts "${NUM_PROMPTS}" \
- --hf-split "${DATASET_SPLIT}"
+ --model Qwen/Qwen2-VL-7B-Instruct \
+ --backend vllm-chat \
+ --dataset-name hf \
+ --dataset-path lmarena-ai/VisionArena-Chat \
+ --num-prompts 1000 \
+ --hf-split train
```
The `num prompt tokens` now includes image token counts
@@ -240,29 +226,71 @@ Total num prompt tokens: 14527
Total num output tokens: 1280
```
+### InstructCoder Benchmark with Speculative Decoding
+
+``` bash
+VLLM_WORKER_MULTIPROC_METHOD=spawn \
+VLLM_USE_V1=1 \
+python3 vllm/benchmarks/benchmark_throughput.py \
+ --dataset-name=hf \
+ --dataset-path=likaixin/InstructCoder \
+ --model=meta-llama/Meta-Llama-3-8B-Instruct \
+ --input-len=1000 \
+ --output-len=100 \
+ --num-prompts=2048 \
+ --async-engine \
+ --speculative-model="[ngram]" \
+    --ngram-prompt-lookup-min=2 \
+    --ngram-prompt-lookup-max=5 \
+    --num-speculative-tokens=5
+```
+
+```
+Throughput: 104.77 requests/s, 23836.22 total tokens/s, 10477.10 output tokens/s
+Total num prompt tokens: 261136
+Total num output tokens: 204800
+```
+
+### Other HuggingFaceDataset Examples
+
+**`lmms-lab/LLaVA-OneVision-Data`**
+
+```bash
+python3 vllm/benchmarks/benchmark_throughput.py \
+ --model Qwen/Qwen2-VL-7B-Instruct \
+ --backend vllm-chat \
+ --dataset-name hf \
+ --dataset-path lmms-lab/LLaVA-OneVision-Data \
+ --hf-split train \
+ --hf-subset "chart2text(cauldron)" \
+ --num-prompts 10
+```
+
+**`Aeala/ShareGPT_Vicuna_unfiltered`**
+
+```bash
+python3 vllm/benchmarks/benchmark_throughput.py \
+ --model Qwen/Qwen2-VL-7B-Instruct \
+ --backend vllm-chat \
+ --dataset-name hf \
+ --dataset-path Aeala/ShareGPT_Vicuna_unfiltered \
+ --hf-split train \
+ --num-prompts 10
+```
+
### Benchmark with LoRA Adapters
``` bash
# download dataset
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
-MODEL_NAME="meta-llama/Llama-2-7b-hf"
-BACKEND="vllm"
-DATASET_NAME="sharegpt"
-DATASET_PATH="/ShareGPT_V3_unfiltered_cleaned_split.json"
-NUM_PROMPTS=10
-MAX_LORAS=2
-MAX_LORA_RANK=8
-ENABLE_LORA="--enable-lora"
-LORA_PATH="yard1/llama-2-7b-sql-lora-test"
-
python3 vllm/benchmarks/benchmark_throughput.py \
- --model "${MODEL_NAME}" \
- --backend "${BACKEND}" \
- --dataset_path "${DATASET_PATH}" \
- --dataset_name "${DATASET_NAME}" \
- --num-prompts "${NUM_PROMPTS}" \
- --max-loras "${MAX_LORAS}" \
- --max-lora-rank "${MAX_LORA_RANK}" \
- ${ENABLE_LORA} \
- --lora-path "${LORA_PATH}"
+ --model meta-llama/Llama-2-7b-hf \
+ --backend vllm \
+    --dataset-path /ShareGPT_V3_unfiltered_cleaned_split.json \
+    --dataset-name sharegpt \
+ --num-prompts 10 \
+ --max-loras 2 \
+ --max-lora-rank 8 \
+ --enable-lora \
+ --lora-path yard1/llama-2-7b-sql-lora-test
```
diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py
index 38ef739c..f332566d 100644
--- a/benchmarks/benchmark_dataset.py
+++ b/benchmarks/benchmark_dataset.py
@@ -23,7 +23,8 @@ from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass
from functools import cache
-from typing import Any, Optional, Union
+from io import BytesIO
+from typing import Any, Callable, Optional, Union
import numpy as np
import pandas as pd
@@ -239,21 +240,24 @@ def process_image(image: Any) -> Mapping[str, Any]:
"""
Process a single image input and return a multimedia content dictionary.
- For a PIL.Image.Image input:
- - Converts the image to RGB.
- - Saves the image as a JPEG in-memory.
- - Encodes the JPEG data as a base64 string.
- - Returns a dictionary with the image as a base64 data URL.
+ Supports three input types:
- For a string input:
- - Treats the string as a URL or file path.
- - Prepends "file://" if the string doesn't start with "http://" or
- "file://".
- - Returns a dictionary with the image URL.
+    1. Dictionary with raw image bytes:
+       - Expects a dict with a 'bytes' key containing raw image data.
+       - Loads the bytes as a PIL.Image.Image.
+
+    2. PIL.Image.Image input:
+       - Converts the image to RGB.
+       - Saves the image as a JPEG in memory.
+       - Encodes the JPEG data as a base64 string.
+       - Returns a dictionary with the image as a base64 data URL.
+
+    3. String input:
+       - Treats the string as a URL or local file path.
+       - Prepends "file://" if the string doesn't start with "http://"
+         or "file://".
+       - Returns a dictionary with the image URL.
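+
+    Example (illustrative; any local path behaves the same way):
+        >>> process_image("/data/img.jpg")
+        {'type': 'image_url', 'image_url': {'url': 'file:///data/img.jpg'}}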
Raises:
- ValueError: If the input is neither a PIL.Image.Image nor a string.
+ ValueError: If the input is not a supported type.
"""
+ if isinstance(image, dict) and 'bytes' in image:
+ image = Image.open(BytesIO(image['bytes']))
if isinstance(image, Image.Image):
image = image.convert("RGB")
with io.BytesIO() as image_data:
@@ -272,8 +276,8 @@ def process_image(image: Any) -> Mapping[str, Any]:
("http://", "file://")) else f"file://{image}")
return {"type": "image_url", "image_url": {"url": image_url}}
- raise ValueError(
- f"Invalid image input {image}. Must be a PIL.Image.Image or str.")
+    raise ValueError(f"Invalid image input {image}. Must be a "
+                     "PIL.Image.Image, str, or dict with raw image bytes.")
# -----------------------------------------------------------------------------
@@ -562,48 +566,56 @@ class BurstGPTDataset(BenchmarkDataset):
# -----------------------------------------------------------------------------
-# HuggingFace Dataset Implementation
+# HuggingFace Dataset Base Implementation
# -----------------------------------------------------------------------------
-
-
class HuggingFaceDataset(BenchmarkDataset):
- """
- Dataset class for processing a HuggingFace dataset with conversation data
- and optional images.
- """
+ """Base class for datasets hosted on HuggingFace."""
+
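+    # Subclasses narrow this to the dataset paths they accept: either a
+    # set of path strings, or a dict mapping a path to a callable that
+    # extracts the prompt from one record of that dataset.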
+ SUPPORTED_DATASET_PATHS: Union[set[str], dict[str, Callable]] = set()
def __init__(
self,
+ dataset_path: str,
dataset_split: str,
dataset_subset: Optional[str] = None,
**kwargs,
) -> None:
- super().__init__(**kwargs)
+ super().__init__(dataset_path=dataset_path, **kwargs)
+
+ # Validate dataset path
+ if self.SUPPORTED_DATASET_PATHS and \
+ self.dataset_path not in self.SUPPORTED_DATASET_PATHS:
+ raise ValueError(
+ f"{self.__class__.__name__} "
+ f"only supports: {', '.join(self.SUPPORTED_DATASET_PATHS)}. "
+ "Please consider contributing if you would "
+ "like to add support for additional dataset formats.")
+
self.dataset_split = dataset_split
self.dataset_subset = dataset_subset
-
self.load_data()
def load_data(self) -> None:
- if not self.dataset_path:
- raise ValueError("dataset_path must be provided for loading data.")
-
+ """Load data from HuggingFace datasets."""
self.data = load_dataset(
self.dataset_path,
name=self.dataset_subset,
split=self.dataset_split,
streaming=True,
)
- if self.data.features is None or "conversations" \
- not in self.data.features:
- raise ValueError(
- "HuggingFaceDataset currently only supports datasets with "
- "a 'conversations' column like lmms-lab/LLaVA-OneVision-Data. "
- "Please consider contributing if you would like to add "
- "support for additional dataset formats.")
- # Shuffle and filter examples with at least 2 conversations.
- self.data = self.data.shuffle(seed=self.random_seed).filter(
- lambda x: len(x["conversations"]) >= 2)
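+        # With streaming=True, shuffle() performs an approximate,
+        # buffer-based shuffle rather than a full global shuffle.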
+ self.data = self.data.shuffle(seed=self.random_seed)
+
+
+# -----------------------------------------------------------------------------
+# Conversation Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class ConversationDataset(HuggingFaceDataset):
+ """Dataset for conversation data with multimodal support."""
+ SUPPORTED_DATASET_PATHS = {
+ 'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
+ }
def sample(self,
tokenizer: PreTrainedTokenizerBase,
@@ -611,10 +623,13 @@ class HuggingFaceDataset(BenchmarkDataset):
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
**kwargs) -> list:
+ # Filter examples with at least 2 conversations
+ filtered_data = self.data.filter(
+ lambda x: len(x["conversations"]) >= 2)
sampled_requests = []
dynamic_output = output_len is None
- for item in self.data:
+ for item in filtered_data:
if len(sampled_requests) >= num_requests:
break
conv = item["conversations"]
@@ -659,29 +674,12 @@ class VisionArenaDataset(HuggingFaceDataset):
"""
DEFAULT_OUTPUT_LEN = 128
- VISION_ARENA_DATASET_PATH = "lmarena-ai/vision-arena-bench-v0.1"
-
- def __init__(
- self,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- if self.dataset_path != self.VISION_ARENA_DATASET_PATH:
- raise ValueError(f"Only support Vision Arena dataset.\
- This data path {self.dataset_path} is not valid.")
- if self.dataset_subset is None and self.dataset_split != "train":
- raise ValueError("Dataset split must be 'train'.")
-
- self.load_data()
-
- def load_data(self) -> None:
- dataset = load_dataset(
- self.dataset_path,
- name=self.dataset_subset,
- split=self.dataset_split,
- streaming=True,
- )
- self.data = dataset.shuffle(seed=self.random_seed)
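+    # Each supported path maps to a parser that pulls the first user turn
+    # out of that dataset's record layout.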
+ SUPPORTED_DATASET_PATHS = {
+ "lmarena-ai/VisionArena-Chat":
+ lambda x: x["conversation"][0][0]["content"],
+ "lmarena-ai/vision-arena-bench-v0.1":
+ lambda x: x["turns"][0][0]["content"]
+ }
def sample(
self,
@@ -697,7 +695,11 @@ class VisionArenaDataset(HuggingFaceDataset):
+        parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
+        if parser_fn is None:
+            raise ValueError(
+                f"Unsupported dataset path: {self.dataset_path}")
        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
-            prompt = item["turns"][0][0]["content"]
+            prompt = parser_fn(item)
mm_content = process_image(item["images"][0])
prompt_len = len(tokenizer(prompt).input_ids)
if enable_multimodal_chat:
@@ -727,34 +729,15 @@ class InstructCoderDataset(HuggingFaceDataset):
InstructCoder Dataset.
https://huggingface.co/datasets/likaixin/InstructCoder
- InstructCoder is the dataset designed for general code editing.
- It consists of 114,239 instruction-input-output triplets,
- and covers multiple distinct code editing scenario.
+    InstructCoder is a dataset designed for general code editing. It consists
+    of 114,239 instruction-input-output triplets and covers multiple distinct
+    code editing scenarios.
"""
DEFAULT_OUTPUT_LEN = 200 # this is the average default output length
- DEFAULT_NUM_REQUESTS = 1000
- INSTRUCT_CODER_DATASET_PATH = "likaixin/InstructCoder"
-
- def __init__(
- self,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- if self.dataset_path != self.INSTRUCT_CODER_DATASET_PATH:
- raise ValueError(f"Only support likaixin/InstructCoder dataset.\
- This data path {self.dataset_path} is not valid.")
- if self.dataset_subset is None and self.dataset_split != "train":
- raise ValueError("Dataset split must be 'train'.")
-
- def load_data(self) -> None:
- dataset = load_dataset(
- self.dataset_path,
- name=self.dataset_subset,
- split=self.dataset_split,
- streaming=True,
- )
- self.data = dataset.shuffle(seed=self.random_seed)
+ SUPPORTED_DATASET_PATHS = {
+ "likaixin/InstructCoder",
+ }
def sample(self,
tokenizer: PreTrainedTokenizerBase,
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index e2f712df..dabf2214 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -49,7 +49,7 @@ try:
except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser
-from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
+from benchmark_dataset import (BurstGPTDataset, ConversationDataset,
InstructCoderDataset, RandomDataset,
SampleRequest, ShareGPTDataset, SonnetDataset,
VisionArenaDataset)
@@ -584,16 +584,19 @@ def main(args: argparse.Namespace):
return_prompt_formatted=True)
elif args.dataset_name == "hf":
- # Choose between VisionArenaDataset
- # and HuggingFaceDataset based on provided parameters.
- dataset_class = HuggingFaceDataset
- if args.dataset_path == VisionArenaDataset.VISION_ARENA_DATASET_PATH:
- assert args.hf_subset is None, "VisionArenaDataset needs hf_subset to be None." #noqa: E501
+        # All of the following dataset classes derive from the
+        # HuggingFaceDataset base class.
+ if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
dataset_class = VisionArenaDataset
- elif args.dataset_path == "likaixin/InstructCoder":
+ args.hf_split = "train"
+ args.hf_subset = None
+ elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
dataset_class = InstructCoderDataset
args.hf_split = "train"
-
+        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
+            dataset_class = ConversationDataset
+        else:
+            raise ValueError(
+                f"Unsupported dataset path: {args.dataset_path}")
input_requests = dataset_class(
dataset_path=args.dataset_path,
dataset_subset=args.hf_subset,
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index f2f68b0d..1ff63f0a 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -11,7 +11,7 @@ from typing import Any, Optional, Union
import torch
import uvloop
-from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
+from benchmark_dataset import (BurstGPTDataset, ConversationDataset,
InstructCoderDataset, RandomDataset,
SampleRequest, ShareGPTDataset, SonnetDataset,
VisionArenaDataset)
@@ -319,21 +319,19 @@ def get_requests(args, tokenizer):
elif args.dataset_name == "burstgpt":
dataset_cls = BurstGPTDataset
elif args.dataset_name == "hf":
- if args.dataset_path == VisionArenaDataset.VISION_ARENA_DATASET_PATH:
- if args.args.backend == "vllm-chat":
- raise ValueError(
- "hf datasets only are supported by vllm-chat backend")
- # Choose between VisionArenaDataset and HuggingFaceDataset based on
- # provided parameters.
- dataset_cls = (VisionArenaDataset if args.dataset_path
- == VisionArenaDataset.VISION_ARENA_DATASET_PATH
- and args.hf_subset is None else HuggingFaceDataset)
+ if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
+ dataset_cls = VisionArenaDataset
+ common_kwargs['dataset_subset'] = None
+ common_kwargs['dataset_split'] = "train"
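+            # VisionArena records carry images, so build the requests as
+            # multimodal chat messages rather than plain text prompts.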
+ sample_kwargs["enable_multimodal_chat"] = True
+ elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
+ dataset_cls = InstructCoderDataset
+ common_kwargs['dataset_split'] = "train"
+ elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
+ dataset_cls = ConversationDataset
common_kwargs['dataset_subset'] = args.hf_subset
common_kwargs['dataset_split'] = args.hf_split
sample_kwargs["enable_multimodal_chat"] = True
- elif args.dataset_path == "likaixin/InstructCoder":
- dataset_cls = InstructCoderDataset
- common_kwargs['dataset_split'] = "train"
else:
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
@@ -469,10 +467,12 @@ def validate_args(args):
since --dataset-name is not 'hf'.",
stacklevel=2)
elif args.dataset_name == "hf":
- if args.dataset_path == VisionArenaDataset.VISION_ARENA_DATASET_PATH:
+ if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
assert args.backend == "vllm-chat", "VisionArenaDataset needs to use vllm-chat as the backend." #noqa: E501
- elif args.dataset_path == "likaixin/InstructCoder":
+ elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
assert args.backend == "vllm", "InstructCoder dataset needs to use vllm as the backend." #noqa: E501
+ elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
+ assert args.backend == "vllm-chat", "ConversationDataset needs to use vllm-chat as the backend." #noqa: E501
else:
raise ValueError(
f"{args.dataset_path} is not supported by hf dataset.")