[Benchmark] Allow oversample request in benchmark dataset (#15170)
Signed-off-by: Jennifer Zhao <ai.jenniferzhao@gmail.com>
parent d8c6d7d6b5
commit b88be22165
@@ -42,7 +42,7 @@ become available.
 </tr>
 <tr>
 <td><strong>HuggingFace</strong></td>
-<td style="text-align: center;">✅</td>
+<td style="text-align: center;">🟡</td>
 <td style="text-align: center;">🟡</td>
 <td>Specify your dataset path on HuggingFace</td>
 </tr>
@@ -60,8 +60,8 @@ become available.
 🚧: to be supported
 
 🟡: Partial support. Currently, HuggingFaceDataset only supports dataset formats
-similar to `lmms-lab/LLaVA-OneVision-Data`. If you need support for other dataset
-formats, please consider contributing.
+similar to `lmms-lab/LLaVA-OneVision-Data` and `Aeala/ShareGPT_Vicuna_unfiltered`.
+If you need support for other dataset formats, please consider contributing.
 
 **Note**: VisionArena’s `dataset-name` should be set to `hf`
 
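For orientation, the "formats similar to" wording above refers to rows that carry a conversations list and an optional image. A rough sketch of the row shape the loader handles — field names here are inferred from the referenced datasets' cards, not guaranteed by this patch:

```python
# Rough, inferred shape of a row the HuggingFace loader can parse.
# Field names are assumptions based on the referenced datasets' cards,
# not guarantees from this commit.
example_row = {
    "conversations": [
        {"from": "human", "value": "<image>\nWhat does this chart show?"},
        {"from": "gpt", "value": "Monthly revenue, grouped by region."},
    ],
    "image": None,  # optional PIL image for multimodal rows
}
print(example_row["conversations"][0]["value"])
```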
@@ -139,6 +139,57 @@ python3 vllm/benchmarks/benchmark_serving.py \
   --num-prompts "${NUM_PROMPTS}"
 ```
 
+### HuggingFaceDataset Examples
+
+Currently, HuggingFaceDataset only supports dataset formats
+similar to `lmms-lab/LLaVA-OneVision-Data` and `Aeala/ShareGPT_Vicuna_unfiltered`. If you need support for other dataset
+formats, please consider contributing.
+
+```bash
+# need a model with vision capability here
+vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests
+```
+
+**`lmms-lab/LLaVA-OneVision-Data`**
+
+```bash
+MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
+NUM_PROMPTS=10
+BACKEND="openai-chat"
+DATASET_NAME="hf"
+DATASET_PATH="lmms-lab/LLaVA-OneVision-Data"
+DATASET_SPLIT='train'
+DATASET_SUBSET='chart2text(cauldron)'
+python3 vllm/benchmarks/benchmark_serving.py \
+  --backend "${BACKEND}" \
+  --model "${MODEL_NAME}" \
+  --endpoint "/v1/chat/completions" \
+  --dataset-name "${DATASET_NAME}" \
+  --dataset-path "${DATASET_PATH}" \
+  --hf-split "${DATASET_SPLIT}" \
+  --num-prompts "${NUM_PROMPTS}" \
+  --hf-subset "${DATASET_SUBSET}"
+```
+
+**`Aeala/ShareGPT_Vicuna_unfiltered`**
+
+```bash
+MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
+NUM_PROMPTS=10
+BACKEND="openai-chat"
+DATASET_NAME="hf"
+DATASET_PATH="Aeala/ShareGPT_Vicuna_unfiltered"
+DATASET_SPLIT='train'
+python3 vllm/benchmarks/benchmark_serving.py \
+  --backend "${BACKEND}" \
+  --model "${MODEL_NAME}" \
+  --endpoint "/v1/chat/completions" \
+  --dataset-name "${DATASET_NAME}" \
+  --dataset-path "${DATASET_PATH}" \
+  --hf-split "${DATASET_SPLIT}" \
+  --num-prompts "${NUM_PROMPTS}"
+```
+
 ---
 ## Example - Offline Throughput Benchmark
 
@@ -17,6 +17,7 @@ SampleRequest instances, similar to the approach used in ShareGPT.
 import base64
 import io
 import json
+import logging
 import random
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
@@ -35,6 +36,8 @@ from vllm.lora.utils import get_adapter_absolute_path
 from vllm.multimodal import MultiModalDataDict
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
 
+logger = logging.getLogger(__name__)
+
 # -----------------------------------------------------------------------------
 # Data Classes
 # -----------------------------------------------------------------------------
@@ -61,9 +64,6 @@ class SampleRequest:
 class BenchmarkDataset(ABC):
     DEFAULT_SEED = 0
 
-    # num_requests has default 1000 in both the benchmark_serving.py and
-    # benchmark_throughput.py
-
     def __init__(
         self,
         dataset_path: Optional[str] = None,
@@ -90,8 +90,8 @@ class BenchmarkDataset(ABC):
             mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
         """
         Transform a prompt and optional multimodal content into a chat format.
-        This method is used for chat models that expect a specific
-        conversation format.
+        This method is used for chat models that expect a specific conversation
+        format.
         """
         content = [{"text": prompt, "type": "text"}]
         if mm_content is not None:
@@ -175,6 +175,24 @@ class BenchmarkDataset(ABC):
         """
         raise NotImplementedError("sample must be implemented in subclasses.")
 
+    def maybe_oversample_requests(self, requests: list[SampleRequest],
+                                  num_requests: int) -> None:
+        """
+        Oversamples the list of requests if its size is less than the desired
+        number.
+
+        Args:
+            requests (List[SampleRequest]): The current list of sampled
+                requests.
+            num_requests (int): The target number of requests.
+        """
+        if len(requests) < num_requests:
+            random.seed(self.random_seed)
+            additional = random.choices(requests,
+                                        k=num_requests - len(requests))
+            requests.extend(additional)
+            logger.info("Oversampled requests to reach %d total samples.",
+                        num_requests)
+
 
 # -----------------------------------------------------------------------------
 # Utility Functions and Global Caches
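To make the new helper's behavior concrete, here is a minimal standalone sketch of the same oversampling logic; the function name and toy data are illustrative, not part of the patch:

```python
import random

def maybe_oversample(requests: list, num_requests: int, seed: int = 0) -> None:
    """Pad `requests` in place, with replacement, up to `num_requests` items."""
    if len(requests) < num_requests:
        random.seed(seed)  # fixed seed keeps benchmark runs reproducible
        additional = random.choices(requests, k=num_requests - len(requests))
        requests.extend(additional)

# Toy stand-ins for SampleRequest objects.
requests = ["req-a", "req-b", "req-c"]
maybe_oversample(requests, num_requests=5)
print(len(requests))  # 5: the original 3 plus 2 resampled duplicates
```

If the list already holds at least `num_requests` items, the helper is a no-op; note that the patch deliberately resamples with replacement, so duplicate requests are expected whenever the gap is large.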
@@ -276,15 +294,16 @@ class RandomDataset(BenchmarkDataset):
     ) -> None:
         super().__init__(**kwargs)
 
-    def sample(self,
-               tokenizer: PreTrainedTokenizerBase,
-               num_requests: int,
-               prefix_len: int = DEFAULT_PREFIX_LEN,
-               range_ratio: float = DEFAULT_RANGE_RATIO,
-               input_len: int = DEFAULT_INPUT_LEN,
-               output_len: int = DEFAULT_OUTPUT_LEN,
-               **kwargs) -> list[SampleRequest]:
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        prefix_len: int = DEFAULT_PREFIX_LEN,
+        range_ratio: float = DEFAULT_RANGE_RATIO,
+        input_len: int = DEFAULT_INPUT_LEN,
+        output_len: int = DEFAULT_OUTPUT_LEN,
+        **kwargs,
+    ) -> list[SampleRequest]:
         vocab_size = tokenizer.vocab_size
 
         prefix_token_ids = (np.random.randint(
@@ -346,20 +365,24 @@ class ShareGPTDataset(BenchmarkDataset):
         random.seed(self.random_seed)
         random.shuffle(self.data)
 
-    def sample(self,
-               tokenizer: PreTrainedTokenizerBase,
-               num_requests: int,
-               lora_path: Optional[str] = None,
-               max_loras: Optional[int] = None,
-               output_len: Optional[int] = None,
-               enable_multimodal_chat: bool = False,
-               **kwargs) -> list:
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        lora_path: Optional[str] = None,
+        max_loras: Optional[int] = None,
+        output_len: Optional[int] = None,
+        enable_multimodal_chat: bool = False,
+        **kwargs,
+    ) -> list:
         samples: list = []
         for entry in self.data:
             if len(samples) >= num_requests:
                 break
-            prompt, completion = entry["conversations"][0]["value"],\
-                entry["conversations"][1]["value"]
+            prompt, completion = (
+                entry["conversations"][0]["value"],
+                entry["conversations"][1]["value"],
+            )
 
             lora_request, tokenizer = self.get_random_lora_request(
                 tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
@@ -383,6 +406,7 @@ class ShareGPTDataset(BenchmarkDataset):
                     expected_output_len=new_output_len,
                     lora_request=lora_request,
                 ))
+        self.maybe_oversample_requests(samples, num_requests)
        return samples
 
 
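With this one-line call, `ShareGPTDataset.sample` now always returns `num_requests` items even when the underlying JSON has fewer usable conversations. A hedged usage sketch — the module import path, local filename, and constructor arguments below are assumptions based on the surrounding code, not taken from the patch:

```python
from transformers import AutoTokenizer

# Assuming this module is importable as benchmark_dataset; the dataset file
# name is a hypothetical local path, not part of this commit.
from benchmark_dataset import ShareGPTDataset

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
dataset = ShareGPTDataset(
    dataset_path="ShareGPT_V3_unfiltered_cleaned_split.json",  # hypothetical
)
samples = dataset.sample(tokenizer=tokenizer, num_requests=2000)
assert len(samples) == 2000  # oversampled if the file has fewer usable entries
```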
@@ -415,19 +439,20 @@ class SonnetDataset(BenchmarkDataset):
         with open(self.dataset_path, encoding="utf-8") as f:
             self.data = f.readlines()
 
-    def sample(self,
-               tokenizer,
-               num_requests: int,
-               prefix_len: int = DEFAULT_PREFIX_LEN,
-               input_len: int = DEFAULT_INPUT_LEN,
-               output_len: int = DEFAULT_OUTPUT_LEN,
-               return_prompt_formatted: bool = False,
-               **kwargs) -> list:
+    def sample(
+        self,
+        tokenizer,
+        num_requests: int,
+        prefix_len: int = DEFAULT_PREFIX_LEN,
+        input_len: int = DEFAULT_INPUT_LEN,
+        output_len: int = DEFAULT_OUTPUT_LEN,
+        return_prompt_formatted: bool = False,
+        **kwargs,
+    ) -> list:
         # Calculate average token length for a poem line.
         tokenized_lines = [tokenizer(line).input_ids for line in self.data]
-        avg_len = sum(len(tokens)
-                      for tokens in \
-                      tokenized_lines) / len(tokenized_lines)
+        avg_len = sum(len(tokens)
+                      for tokens in tokenized_lines) / len(tokenized_lines)
 
         # Build the base prompt.
         base_prompt = "Pick as many lines as you can from these poem lines:\n"
@@ -506,12 +531,14 @@ class BurstGPTDataset(BenchmarkDataset):
         # Convert the dataframe to a list of lists.
         return data.values.tolist()
 
-    def sample(self,
-               tokenizer: PreTrainedTokenizerBase,
-               num_requests: int,
-               max_loras: Optional[int] = None,
-               lora_path: Optional[str] = None,
-               **kwargs) -> list[SampleRequest]:
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        max_loras: Optional[int] = None,
+        lora_path: Optional[str] = None,
+        **kwargs,
+    ) -> list[SampleRequest]:
         samples = []
         data = self._sample_loaded_data(num_requests=num_requests)
         for i in range(num_requests):
@@ -544,7 +571,6 @@ class HuggingFaceDataset(BenchmarkDataset):
     Dataset class for processing a HuggingFace dataset with conversation data
     and optional images.
     """
-    DEFAULT_NUM_REQUESTS = 1000
 
     def __init__(
         self,
@@ -618,6 +644,7 @@ class HuggingFaceDataset(BenchmarkDataset):
                     expected_output_len=output_len,
                     multi_modal_data=mm_content,
                 ))
+        self.maybe_oversample_requests(sampled_requests, num_requests)
         return sampled_requests
 
 
@@ -632,7 +659,6 @@ class VisionArenaDataset(HuggingFaceDataset):
     """
 
     DEFAULT_OUTPUT_LEN = 128
-    DEFAULT_NUM_REQUESTS = 1000
     VISION_ARENA_DATASET_PATH = "lmarena-ai/vision-arena-bench-v0.1"
 
     def __init__(
@@ -657,12 +683,14 @@ class VisionArenaDataset(HuggingFaceDataset):
         )
         self.data = dataset.shuffle(seed=self.random_seed)
 
-    def sample(self,
-               tokenizer: PreTrainedTokenizerBase,
-               num_requests: int,
-               output_len: Optional[int] = None,
-               enable_multimodal_chat: bool = False,
-               **kwargs) -> list:
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        output_len: Optional[int] = None,
+        enable_multimodal_chat: bool = False,
+        **kwargs,
+    ) -> list:
         output_len = (output_len
                       if output_len is not None else self.DEFAULT_OUTPUT_LEN)
         sampled_requests = []
@@ -685,4 +713,5 @@ class VisionArenaDataset(HuggingFaceDataset):
                     expected_output_len=output_len,
                     multi_modal_data=mm_content,
                 ))
+        self.maybe_oversample_requests(sampled_requests, num_requests)
         return sampled_requests
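One edge case worth noting in the new helper: `random.choices` raises `IndexError` on an empty population, so a dataset split that yields zero usable samples fails loudly rather than being silently padded. A quick illustration in plain Python, independent of the patch:

```python
import random

population: list = []
try:
    random.choices(population, k=3)  # cannot draw from an empty sequence
except IndexError:
    print("oversampling an empty request list raises IndexError")
```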