From 1253b1577408f7981d11495b1fda71cbcbe48dc4 Mon Sep 17 00:00:00 2001 From: Jennifer Zhao Date: Mon, 10 Mar 2025 00:23:11 -0700 Subject: [PATCH] [Feature] Consolidate performance benchmark datasets (#14036) Signed-off-by: Jennifer Zhao <7443418+JenZhao@users.noreply.github.com> Signed-off-by: Roger Wang Co-authored-by: Jennifer Zhao <7443418+JenZhao@users.noreply.github.com> Co-authored-by: Roger Wang --- benchmarks/benchmark_dataset.py | 667 +++++++++++++++++++++++++++++ benchmarks/benchmark_serving.py | 459 ++++---------------- benchmarks/benchmark_throughput.py | 278 ++++-------- 3 files changed, 825 insertions(+), 579 deletions(-) create mode 100644 benchmarks/benchmark_dataset.py diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py new file mode 100644 index 00000000..30fffdda --- /dev/null +++ b/benchmarks/benchmark_dataset.py @@ -0,0 +1,667 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This module defines a framework for sampling benchmark requests from various +datasets. Each dataset subclass of BenchmarkDataset must implement sample +generation. Supported dataset types include: + - ShareGPT + - Random (synthetic) + - Sonnet + - BurstGPT + - HuggingFace + - VisionArena + +TODO: Implement CustomDataset to parse a JSON file and convert its contents into +SampleRequest instances, similar to the approach used in ShareGPT. +""" + +import base64 +import io +import json +import random +from abc import ABC, abstractmethod +from collections.abc import Mapping +from dataclasses import dataclass +from functools import cache +from typing import Any, Optional, Union + +import numpy as np +import pandas as pd +from datasets import load_dataset +from PIL import Image +from transformers import PreTrainedTokenizerBase + +from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path +from vllm.multimodal import MultiModalDataDict +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer + +# ----------------------------------------------------------------------------- +# Data Classes +# ----------------------------------------------------------------------------- + + +@dataclass +class SampleRequest: + """ + Represents a single inference request for benchmarking. + """ + + prompt: str + prompt_len: int + expected_output_len: int + multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None + lora_request: Optional[LoRARequest] = None + + +# ----------------------------------------------------------------------------- +# Benchmark Dataset Base Class +# ----------------------------------------------------------------------------- + + +class BenchmarkDataset(ABC): + DEFAULT_SEED = 0 + + # num_requests has default 1000 in both the benchmark_serving.py and + # benchmark_throughput.py + + def __init__( + self, + dataset_path: Optional[str] = None, + random_seed: int = DEFAULT_SEED, + ) -> None: + """ + Initialize the BenchmarkDataset with an optional dataset path and random + seed. Args: + dataset_path (Optional[str]): Path to the dataset. If None, it + indicates that a default or random dataset might be used. + random_seed (int): Seed value for reproducible shuffling or + sampling. Defaults to DEFAULT_SEED. + """ + self.dataset_path = dataset_path + # Set the random seed, ensuring that a None value is replaced with the + # default seed. 
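+        # A None seed therefore behaves exactly like DEFAULT_SEED, keeping
+        # shuffling and sampling reproducible across runs.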
+ self.random_seed = (random_seed + if random_seed is not None else self.DEFAULT_SEED) + self.data = None + + def load_data(self) -> None: + """ + Load data from the dataset path into self.data. + + This method must be overridden by subclasses since the method to load + data will vary depending on the dataset format and source. + + Raises: + NotImplementedError: If a subclass does not implement this method. + """ + # TODO (jenniferzhao): add support for downloading data + raise NotImplementedError( + "load_data must be implemented in subclasses.") + + def get_random_lora_request( + self, + tokenizer: PreTrainedTokenizerBase, + max_loras: Optional[int] = None, + lora_path: Optional[str] = None, + ) -> tuple[Optional[LoRARequest], AnyTokenizer]: + """ + Optionally select a random LoRA request and return its associated + tokenizer. + + This method is used when LoRA parameters are provided. It randomly + selects a LoRA based on max_loras and retrieves a cached tokenizer for + that LoRA if available. Otherwise, it returns the base tokenizer. + + Args: + tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no + LoRA is selected. max_loras (Optional[int]): The maximum number of + LoRAs available. If None, LoRA is not used. lora_path + (Optional[str]): Path to the LoRA parameters on disk. If None, LoRA + is not used. + + Returns: + tuple[Optional[LoRARequest], AnyTokenizer]: A tuple where the first + element is a LoRARequest (or None if not applicable) and the second + element is the tokenizer associated with the LoRA request (or the + base tokenizer). + """ + if max_loras is None or lora_path is None: + return None, tokenizer + + # Generate a random LoRA ID in the range [1, max_loras]. + lora_id = random.randint(1, max_loras) + lora_request = LoRARequest( + lora_name=str(lora_id), + lora_int_id=lora_id, + lora_path=lora_path_on_disk(lora_path), + ) + if lora_id not in lora_tokenizer_cache: + lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request) + # Return lora_request and the cached tokenizer if available; otherwise, + # return the base tokenizer + return lora_request, lora_tokenizer_cache[lora_id] or tokenizer + + @abstractmethod + def sample(self, tokenizer: PreTrainedTokenizerBase, + num_requests: int) -> list[SampleRequest]: + """ + Abstract method to generate sample requests from the dataset. + + Subclasses must override this method to implement dataset-specific logic + for generating a list of SampleRequest objects. + + Args: + tokenizer (PreTrainedTokenizerBase): The tokenizer to be used + for processing the dataset's text. + num_requests (int): The number of sample requests to generate. + + Returns: + list[SampleRequest]: A list of sample requests generated from the + dataset. + """ + raise NotImplementedError("sample must be implemented in subclasses.") + + +# ----------------------------------------------------------------------------- +# Utility Functions and Global Caches +# ----------------------------------------------------------------------------- + + +def is_valid_sequence( + prompt_len: int, + output_len: int, + min_len: int = 4, + max_prompt_len: int = 1024, + max_total_len: int = 2048, + skip_min_output_len_check: bool = False, +) -> bool: + """ + Validate a sequence based on prompt and output lengths. + + Default pruning criteria are copied from the original `sample_hf_requests` + and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as + from `sample_requests` in benchmark_throughput.py. 
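+
+    A sequence is kept only if the prompt has at least min_len tokens, the
+    output has at least min_len tokens (unless skip_min_output_len_check is
+    set), the prompt is no longer than max_prompt_len, and the combined
+    prompt + output length does not exceed max_total_len.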
+ """ + # Check for invalid conditions + prompt_too_short = prompt_len < min_len + output_too_short = (not skip_min_output_len_check) and (output_len + < min_len) + prompt_too_long = prompt_len > max_prompt_len + combined_too_long = (prompt_len + output_len) > max_total_len + + # Return True if none of the invalid conditions are met + return not (prompt_too_short or output_too_short or prompt_too_long + or combined_too_long) + + +@cache +def lora_path_on_disk(lora_path: str) -> str: + return get_adapter_absolute_path(lora_path) + + +# Global cache for LoRA tokenizers. +lora_tokenizer_cache: dict[int, AnyTokenizer] = {} + + +def process_image(image: Any) -> Mapping[str, Any]: + """ + Process a single image input and return a multimedia content dictionary. + + For a PIL.Image.Image input: + - Converts the image to RGB. + - Saves the image as a JPEG in-memory. + - Encodes the JPEG data as a base64 string. + - Returns a dictionary with the image as a base64 data URL. + + For a string input: + - Treats the string as a URL or file path. + - Prepends "file://" if the string doesn't start with "http://" or + "file://". + - Returns a dictionary with the image URL. + + Raises: + ValueError: If the input is neither a PIL.Image.Image nor a string. + """ + if isinstance(image, Image.Image): + image = image.convert("RGB") + with io.BytesIO() as image_data: + image.save(image_data, format="JPEG") + image_base64 = base64.b64encode( + image_data.getvalue()).decode("utf-8") + return { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + } + + if isinstance(image, str): + image_url = (image if image.startswith( + ("http://", "file://")) else f"file://{image}") + return {"type": "image_url", "image_url": {"url": image_url}} + + raise ValueError( + f"Invalid image input {image}. Must be a PIL.Image.Image or str.") + + +# ----------------------------------------------------------------------------- +# Random Dataset Implementation (Synthetic Data) +# ----------------------------------------------------------------------------- + + +class RandomDataset(BenchmarkDataset): + # Default values copied from benchmark_serving.py for the random dataset. 
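+    # benchmark_serving.py overrides these per run via args.random_prefix_len,
+    # args.random_range_ratio, args.random_input_len and args.random_output_len.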
+ DEFAULT_PREFIX_LEN = 0 + DEFAULT_RANGE_RATIO = 1.0 + DEFAULT_INPUT_LEN = 1024 + DEFAULT_OUTPUT_LEN = 128 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + range_ratio: float = DEFAULT_RANGE_RATIO, + input_len: int = DEFAULT_INPUT_LEN, + output_len: int = DEFAULT_OUTPUT_LEN, + **kwargs) -> list[SampleRequest]: + + vocab_size = tokenizer.vocab_size + + prefix_token_ids = (np.random.randint( + 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) + + input_low = int(input_len * range_ratio) + output_low = int(output_len * range_ratio) + + input_lens = np.random.randint(input_low, + input_len + 1, + size=num_requests) + output_lens = np.random.randint(output_low, + output_len + 1, + size=num_requests) + offsets = np.random.randint(0, vocab_size, size=num_requests) + + requests = [] + for i in range(num_requests): + inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % + vocab_size).tolist() + token_sequence = prefix_token_ids + inner_seq + prompt = tokenizer.decode(token_sequence) + total_input_len = prefix_len + int(input_lens[i]) + requests.append( + SampleRequest( + prompt=prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + )) + return requests + + +# ----------------------------------------------------------------------------- +# ShareGPT Dataset Implementation +# ----------------------------------------------------------------------------- + + +class ShareGPTDataset(BenchmarkDataset): + """ + Implements the ShareGPT dataset. Loads data from a JSON file and generates + sample requests based on conversation turns. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + with open(self.dataset_path, encoding="utf-8") as f: + self.data = json.load(f) + # Filter entries with at least two conversation turns. 
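+        # Only the first two turns (human prompt, assistant completion) are
+        # consumed by sample(), so shorter conversations cannot be benchmarked.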
+ self.data = [ + entry for entry in self.data + if "conversations" in entry and len(entry["conversations"]) >= 2 + ] + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + **kwargs) -> list: + samples: list = [] + for entry in self.data: + if len(samples) >= num_requests: + break + prompt, completion = entry["conversations"][0]["value"],\ + entry["conversations"][1]["value"] + + lora_request, tokenizer = self.get_random_lora_request( + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + new_output_len = (len(completion_ids) + if output_len is None else output_len) + if not is_valid_sequence(prompt_len, + new_output_len, + skip_min_output_len_check=output_len + is not None): + continue + samples.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=new_output_len, + lora_request=lora_request, + )) + return samples + + +# ----------------------------------------------------------------------------- +# Sonnet Dataset Implementation +# ----------------------------------------------------------------------------- + + +class SonnetDataset(BenchmarkDataset): + """ + Simplified implementation of the Sonnet dataset. Loads poem lines from a + text file and generates sample requests. Default values here copied from + `benchmark_serving.py` for the sonnet dataset. + """ + + DEFAULT_PREFIX_LEN = 200 + DEFAULT_INPUT_LEN = 550 + DEFAULT_OUTPUT_LEN = 150 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if not self.dataset_path: + raise ValueError("dataset_path must be provided.") + with open(self.dataset_path, encoding="utf-8") as f: + self.data = f.readlines() + + def sample(self, + tokenizer, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + input_len: int = DEFAULT_INPUT_LEN, + output_len: int = DEFAULT_OUTPUT_LEN, + return_prompt_formatted: bool = False, + **kwargs) -> list: + # Calculate average token length for a poem line. + tokenized_lines = [tokenizer(line).input_ids for line in self.data] + avg_len = sum(len(tokens) + for tokens in \ + tokenized_lines) / len(tokenized_lines) + + # Build the base prompt. + base_prompt = "Pick as many lines as you can from these poem lines:\n" + base_msg = [{"role": "user", "content": base_prompt}] + base_fmt = tokenizer.apply_chat_template(base_msg, + add_generation_prompt=True, + tokenize=False) + base_offset = len(tokenizer(base_fmt).input_ids) + if input_len <= base_offset: + raise ValueError( + f"'input_len' must be higher than the base prompt length " + f"({base_offset}).") + + # Determine how many poem lines to use. 
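+        # Subtract the chat-template overhead (base_offset) from each token
+        # budget, then divide by the average tokens per poem line.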
+ num_input_lines = round((input_len - base_offset) / avg_len) + num_prefix_lines = round((prefix_len - base_offset) / avg_len) + prefix_lines = self.data[:num_prefix_lines] + + samples = [] + for _ in range(num_requests): + extra_lines = random.choices(self.data, + k=num_input_lines - num_prefix_lines) + prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}" + msg = [{"role": "user", "content": prompt}] + prompt_formatted = tokenizer.apply_chat_template( + msg, add_generation_prompt=True, tokenize=False) + prompt_len = len(tokenizer(prompt_formatted).input_ids) + samples.append( + SampleRequest( + prompt=prompt_formatted + if return_prompt_formatted else prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + return samples + + +# ----------------------------------------------------------------------------- +# BurstGPT Dataset Implementation +# ----------------------------------------------------------------------------- + + +class BurstGPTDataset(BenchmarkDataset): + """ + Implements the BurstGPT dataset. Loads data from a CSV file and generates + sample requests based on synthetic prompt generation. Only rows with Model + "GPT-4" and positive response tokens are used. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self, ): + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + df = pd.read_csv(self.dataset_path) + # Filter to keep only GPT-4 rows. + gpt4_df = df[df["Model"] == "GPT-4"] + # Remove failed requests (where Response tokens is 0 or less). + gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0] + # Sample the desired number of rows. + self.data = gpt4_df + + def _sample_loaded_data(self, num_requests: int) -> list: + if num_requests <= len(self.data): + data = self.data.sample(n=num_requests, + random_state=self.random_seed) + else: + data = self.data.sample( + n=num_requests, + random_state=self.random_seed, + replace=True, + ) + # Convert the dataframe to a list of lists. + return data.values.tolist() + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + max_loras: Optional[int] = None, + lora_path: Optional[str] = None, + **kwargs) -> list[SampleRequest]: + samples = [] + data = self._sample_loaded_data(num_requests=num_requests) + for i in range(num_requests): + input_len = int(data[i][2]) + output_len = int(data[i][3]) + lora_req, tokenizer = self.get_random_lora_request( + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + vocab_size = tokenizer.vocab_size + # Generate a synthetic prompt: a list of token IDs computed as (i + + # j) modulo vocab_size. + token_ids = [(i + j) % vocab_size for j in range(input_len)] + prompt = tokenizer.decode(token_ids) + samples.append( + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=output_len, + lora_request=lora_req, + )) + return samples + + +# ----------------------------------------------------------------------------- +# HuggingFace Dataset Implementation +# ----------------------------------------------------------------------------- + + +class HuggingFaceDataset(BenchmarkDataset): + """ + Dataset class for processing a HuggingFace dataset with conversation data + and optional images. 
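+
+    The dataset must expose a ShareGPT-style "conversations" column; when an
+    "image" column is present, it is converted into an OpenAI-style image_url
+    content block via process_image().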
+ """ + DEFAULT_NUM_REQUESTS = 1000 + + def __init__( + self, + dataset_split: str, + dataset_subset: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dataset_split = dataset_split + self.dataset_subset = dataset_subset + + self.load_data() + + def load_data(self) -> None: + if not self.dataset_path: + raise ValueError("dataset_path must be provided for loading data.") + + self.data = load_dataset( + self.dataset_path, + name=self.dataset_subset, + split=self.dataset_split, + streaming=True, + ) + + if "conversations" not in self.data.features: + raise ValueError("HF Dataset must have a 'conversations' column.") + + # Shuffle and filter examples with at least 2 conversations. + self.data = self.data.shuffle(seed=self.random_seed).filter( + lambda x: len(x["conversations"]) >= 2) + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + **kwargs) -> list: + sampled_requests = [] + dynamic_output = output_len is None + + for item in self.data: + if len(sampled_requests) >= num_requests: + break + + conv = item["conversations"] + prompt, completion = conv[0]["value"], conv[1]["value"] + + lora_request, tokenizer = self.get_random_lora_request( + tokenizer, lora_path=lora_path, max_loras=max_loras) + + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + completion_len = len(completion_ids) + output_len = completion_len if dynamic_output else output_len + assert isinstance(output_len, int) and output_len > 0 + if dynamic_output and not is_valid_sequence( + prompt_len, completion_len): + continue + + mm_content = process_image( + item["image"]) if "image" in item else None + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + lora_request=lora_request, + )) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Vision Arena Dataset Implementation +# ----------------------------------------------------------------------------- + + +class VisionArenaDataset(BenchmarkDataset): + """ + Vision Arena Dataset. 
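+
+    Streams lmarena-ai/vision-arena-bench-v0.1 from the HuggingFace Hub and
+    builds multimodal requests from the first turn of each conversation and
+    its first image.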
+ """ + + DEFAULT_OUTPUT_LEN = 128 + DEFAULT_NUM_REQUESTS = 1000 + VISION_ARENA_DATASET_PATH = "lmarena-ai/vision-arena-bench-v0.1" + + def __init__( + self, + dataset_split: str, + dataset_subset: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dataset_split = dataset_split + self.dataset_subset = dataset_subset + + if self.dataset_path != self.VISION_ARENA_DATASET_PATH: + raise ValueError(f"Only support Vision Arena dataset.\ + This data path {self.dataset_path} is not valid.") + if self.dataset_subset is None and self.dataset_split != "train": + raise ValueError("Dataset split must be 'train'.") + + self.load_data() + + def load_data(self) -> None: + dataset = load_dataset( + self.dataset_path, + name=self.dataset_subset, + split=self.dataset_split, + streaming=True, + ) + self.data = dataset.shuffle(seed=self.random_seed) + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: int = DEFAULT_OUTPUT_LEN, + **kwargs) -> list: + # TODO (jenniferzhao): Add support for offline benchmark sampling + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = item["turns"][0][0]["content"] + prompt_len = len(tokenizer(prompt).input_ids) + mm_content = process_image(item["images"][0]) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + )) + return sampled_requests diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index b95c8b14..1dd01ca9 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -25,25 +25,20 @@ On the client side, run: """ import argparse import asyncio -import base64 import gc -import io import json import os import random import time import warnings -from collections.abc import AsyncGenerator, Collection +from collections.abc import AsyncGenerator, Iterable from dataclasses import dataclass from datetime import datetime from typing import Any, Optional import numpy as np -import pandas as pd from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, RequestFuncOutput) -from datasets import load_dataset -from PIL.Image import Image from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase @@ -57,6 +52,9 @@ try: except ImportError: from argparse import ArgumentParser as FlexibleArgumentParser +from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset, + RandomDataset, SampleRequest, ShareGPTDataset, + SonnetDataset, VisionArenaDataset) from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -92,325 +90,18 @@ class BenchmarkMetrics: percentiles_e2el_ms: list[tuple[float, float]] -def sample_sharegpt_requests( - dataset_path: str, - num_requests: int, - tokenizer: PreTrainedTokenizerBase, - fixed_output_len: Optional[int] = None, -) -> list[tuple[str, int, int, None]]: - # Load the dataset. - with open(dataset_path, encoding='utf-8') as f: - dataset = json.load(f) - # Filter out the conversations with less than 2 turns. - dataset = [data for data in dataset if len(data["conversations"]) >= 2] - # Only keep the first two turns of each conversation. - dataset = [(data["conversations"][0]["value"], - data["conversations"][1]["value"]) for data in dataset] - - # Shuffle the dataset. 
- random.shuffle(dataset) - - # Filter out sequences that are too long or too short - filtered_dataset: list[tuple[str, int, int]] = [] - for i in range(len(dataset)): - if len(filtered_dataset) == num_requests: - break - - # Tokenize the prompts and completions. - prompt = dataset[i][0] - prompt_token_ids = tokenizer(prompt).input_ids - completion = dataset[i][1] - completion_token_ids = tokenizer(completion).input_ids - prompt_len = len(prompt_token_ids) - output_len = len(completion_token_ids - ) if fixed_output_len is None else fixed_output_len - if prompt_len < 4 or (fixed_output_len is None and output_len < 4): - # Prune too short sequences. - continue - if prompt_len > 1024 or prompt_len + output_len > 2048: - # Prune too long sequences. - continue - filtered_dataset.append((prompt, prompt_len, output_len, None)) - - return filtered_dataset - - -def sample_burstgpt_requests( - dataset_path: str, - num_requests: int, - random_seed: int, - tokenizer: PreTrainedTokenizerBase, -) -> list[tuple[str, int, int, None]]: - df = pd.read_csv(dataset_path) - gpt4_df = df[df["Model"] == "GPT-4"] - # Remove the failed requests (i.e., response length is 0) - gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0] - # Randomly sample num_requests from the dataset - if num_requests <= len(gpt4_df): - gpt4_df = gpt4_df.sample(n=num_requests, random_state=random_seed) - else: - gpt4_df = gpt4_df.sample(n=num_requests, - random_state=random_seed, - replace=True) - # Convert the dataframe to a list of tuples - dataset = gpt4_df.values.tolist() - input_requests = [] - for i in range(num_requests): - input_len = int(dataset[i][2]) - output_len = int(dataset[i][3]) - prompt = tokenizer.decode([(i + j) % tokenizer.vocab_size - for j in range(input_len)]) - input_requests.append((prompt, input_len, output_len, None)) - return input_requests - - -def sample_sonnet_requests( - dataset_path: str, - num_requests: int, - input_len: int, - output_len: int, - prefix_len: int, - tokenizer: PreTrainedTokenizerBase, -) -> list[tuple[str, str, int, int, None]]: - assert ( - input_len > prefix_len - ), "'args.sonnet-input-len' must be greater than 'args.sonnet-prefix-len'." - - # Load the dataset. - with open(dataset_path, encoding='utf-8') as f: - poem_lines = f.readlines() - - # Tokenize the poem lines. - poem_token_ids = tokenizer(poem_lines).input_ids - average_poem_len = sum( - len(token_ids) for token_ids in poem_token_ids) / len(poem_token_ids) - - # Base prefix for all requests. - base_prompt = "Pick as many lines as you can from these poem lines:\n" - base_message = [{ - "role": "user", - "content": base_prompt, - }] - base_prompt_formatted = tokenizer.apply_chat_template( - base_message, add_generation_prompt=True, tokenize=False) - base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids) - - assert ( - input_len > base_prompt_offset - ), f"Please set 'args.sonnet-input-len' higher than {base_prompt_offset}." - num_input_lines = round( - (input_len - base_prompt_offset) / average_poem_len) - - # First approximately `prefix_len` number of tokens in the - # prompt are fixed poem lines. - assert ( - prefix_len > base_prompt_offset - ), f"Please set 'args.sonnet-prefix-len' higher than {base_prompt_offset}." - - num_prefix_lines = round( - (prefix_len - base_prompt_offset) / average_poem_len) - prefix_lines = poem_lines[:num_prefix_lines] - - # Sample the rest of lines per request. 
- sampled_requests: list[tuple[str, int, int]] = [] - for _ in range(num_requests): - num_lines_needed = num_input_lines - num_prefix_lines - sampled_lines = "".join(prefix_lines + - random.choices(poem_lines, k=num_lines_needed)) - - prompt = f"{base_prompt}{sampled_lines}" - message = [ - { - "role": "user", - "content": prompt, - }, - ] - prompt_formatted = tokenizer.apply_chat_template( - message, add_generation_prompt=True, tokenize=False) - prompt_len = len(tokenizer(prompt_formatted).input_ids) - sampled_requests.append( - (prompt, prompt_formatted, prompt_len, output_len, None)) - - return sampled_requests - - -def sample_vision_arena_requests( - dataset, - num_requests: int, - tokenizer: PreTrainedTokenizerBase, - fixed_output_len: Optional[int] = None, -) -> list[tuple[str, str, int, Optional[dict[str, Collection[str]]]]]: - sampled_requests: list[tuple[str, int, int, dict[str, - Collection[str]]]] = [] - for data in dataset: - if len(sampled_requests) == num_requests: - break - - prompt = data["turns"][0][0]['content'] - - prompt_token_ids = tokenizer(prompt).input_ids - if fixed_output_len is None: - # Default max output len is set to 128 - print("--hf-output-len is not provided. Using default value 128.") - fixed_output_len = 128 - - prompt_len = len(prompt_token_ids) - output_len = fixed_output_len - - assert isinstance( - data["images"][0], - Image), ("Input image format must be `PIL.Image.Image`, " - f"given {type(data['image'])}.") - image: Image = data["images"][0] - image = image.convert("RGB") - image_data = io.BytesIO() - image.save(image_data, format='JPEG') - image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8") - mm_content = { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image_base64}" - }, - } - - sampled_requests.append((prompt, prompt_len, output_len, mm_content)) - - return sampled_requests - - -def sample_hf_requests( - dataset_path: str, - dataset_subset: Optional[str], - dataset_split: str, - num_requests: int, - tokenizer: PreTrainedTokenizerBase, - random_seed: int, - fixed_output_len: Optional[int] = None, -) -> list[tuple[str, str, int, Optional[dict[str, Collection[str]]]]]: - - # Special case for vision_arena dataset - if dataset_path == 'lmarena-ai/vision-arena-bench-v0.1' \ - and dataset_subset is None: - assert dataset_split == "train" - dataset = load_dataset(dataset_path, - name=dataset_subset, - split=dataset_split, - streaming=True) - dataset = dataset.shuffle(seed=random_seed) - return sample_vision_arena_requests(dataset, num_requests, tokenizer, - fixed_output_len) - - dataset = load_dataset(dataset_path, - name=dataset_subset, - split=dataset_split, - streaming=True) - assert "conversations" in dataset.features, ( - "HF Dataset must have 'conversations' column.") - filter_func = lambda x: len(x["conversations"]) >= 2 - filtered_dataset = dataset.shuffle(seed=random_seed).filter(filter_func) - sampled_requests: list[tuple[str, int, int, dict[str, - Collection[str]]]] = [] - for data in filtered_dataset: - if len(sampled_requests) == num_requests: - break - - # Tokenize the prompts and completions. 
- prompt = data["conversations"][0]["value"] - prompt_token_ids = tokenizer(prompt).input_ids - completion = data["conversations"][1]["value"] - completion_token_ids = tokenizer(completion).input_ids - prompt_len = len(prompt_token_ids) - output_len = len(completion_token_ids - ) if fixed_output_len is None else fixed_output_len - if fixed_output_len is None and (prompt_len < 4 or output_len < 4): - # Prune too short sequences. - continue - if fixed_output_len is None and \ - (prompt_len > 1024 or prompt_len + output_len > 2048): - # Prune too long sequences. - continue - - if "image" in data and isinstance(data["image"], Image): - image: Image = data["image"] - image = image.convert("RGB") - image_data = io.BytesIO() - image.save(image_data, format='JPEG') - image_base64 = base64.b64encode( - image_data.getvalue()).decode("utf-8") - mm_content = { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image_base64}" - }, - } - elif "image" in data and isinstance(data["image"], str): - if (data["image"].startswith("http://") or \ - data["image"].startswith("file://")): - image_url = data["image"] - else: - image_url = f"file://{data['image']}" - - mm_content = { - "type": "image_url", - "image_url": { - "url": image_url - }, - } - else: - mm_content = None - - sampled_requests.append((prompt, prompt_len, output_len, mm_content)) - - return sampled_requests - - -def sample_random_requests( - prefix_len: int, - input_len: int, - output_len: int, - num_prompts: int, - range_ratio: float, - tokenizer: PreTrainedTokenizerBase, -) -> list[tuple[str, int, int]]: - prefix_token_ids = np.random.randint(0, - tokenizer.vocab_size, - size=prefix_len).tolist() - - input_lens = np.random.randint( - int(input_len * range_ratio), - input_len + 1, - size=num_prompts, - ) - output_lens = np.random.randint( - int(output_len * range_ratio), - output_len + 1, - size=num_prompts, - ) - offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) - input_requests = [] - for i in range(num_prompts): - prompt = tokenizer.decode(prefix_token_ids + - [(offsets[i] + i + j) % tokenizer.vocab_size - for j in range(input_lens[i])]) - - input_requests.append((prompt, int(prefix_len + input_lens[i]), - int(output_lens[i]), None)) - - return input_requests - - async def get_request( - input_requests: list[tuple[str, int, int]], + input_requests: list[SampleRequest], request_rate: float, burstiness: float = 1.0, -) -> AsyncGenerator[tuple[str, int, int], None]: +) -> AsyncGenerator[SampleRequest, None]: """ Asynchronously generates requests at a specified rate with OPTIONAL burstiness. Args: input_requests: - A list of input requests, each represented as a tuple. + A list of input requests, each represented as a SampleRequest. request_rate: The rate at which requests are generated (requests/s). burstiness (optional): @@ -422,7 +113,7 @@ async def get_request( in more bursty requests, while a higher burstiness value (burstiness > 1) results in a more uniform arrival of requests. """ - input_requests = iter(input_requests) + input_requests: Iterable[SampleRequest] = iter(input_requests) # Calculate scale parameter theta to maintain the desired request_rate. 
assert burstiness > 0, ( @@ -444,7 +135,7 @@ async def get_request( def calculate_metrics( - input_requests: list[tuple[str, int, int]], + input_requests: list[SampleRequest], outputs: list[RequestFuncOutput], dur_s: float, tokenizer: PreTrainedTokenizerBase, @@ -475,7 +166,7 @@ def calculate_metrics( tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids) actual_output_lens.append(output_len) - total_input += input_requests[i][1] + total_input += input_requests[i].prompt_len tpot = 0 if output_len > 1: latency_minus_ttft = outputs[i].latency - outputs[i].ttft @@ -558,18 +249,18 @@ async def benchmark( model_id: str, model_name: str, tokenizer: PreTrainedTokenizerBase, - input_requests: list[tuple[str, int, int]], + input_requests: list[SampleRequest], logprobs: Optional[int], request_rate: float, burstiness: float, disable_tqdm: bool, profile: bool, selected_percentile_metrics: list[str], - selected_percentiles: list[str], + selected_percentiles: list[float], ignore_eos: bool, goodput_config_dict: dict[str, float], max_concurrency: Optional[int], - lora_modules: Optional[list[str]], + lora_modules: Optional[Iterable[str]], ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -577,12 +268,16 @@ async def benchmark( raise ValueError(f"Unknown backend: {backend}") print("Starting initial single prompt test run...") - test_prompt, test_prompt_len, test_output_len, test_mm_content = ( - input_requests[0]) + test_prompt, test_prompt_len, test_output_len, test_mm_content = \ + input_requests[0].prompt, input_requests[0].prompt_len, \ + input_requests[0].expected_output_len, \ + input_requests[0].multi_modal_data + if backend != "openai-chat" and test_mm_content is not None: # multi-modal benchmark is only available on OpenAI Chat backend. raise ValueError( "Multi-modal content is only supported on 'openai-chat' backend.") + assert test_mm_content is None or isinstance(test_mm_content, dict) test_input = RequestFuncInput( model=model_id, model_name=model_name, @@ -606,7 +301,8 @@ async def benchmark( if lora_modules: # For each input request, choose a LoRA module at random. 
lora_modules = iter( - [random.choice(lora_modules) for _ in range(len(input_requests))]) + [random.choice(lora_modules) \ + for _ in range(len(input_requests))]) if profile: print("Starting profiler...") @@ -652,7 +348,9 @@ async def benchmark( benchmark_start_time = time.perf_counter() tasks: list[asyncio.Task] = [] async for request in get_request(input_requests, request_rate, burstiness): - prompt, prompt_len, output_len, mm_content = request + prompt, prompt_len, output_len, mm_content = request.prompt, \ + request.prompt_len, request.expected_output_len, \ + request.multi_modal_data req_model_id, req_model_name = model_id, model_name if lora_modules: req_lora_module = next(lora_modules) @@ -867,76 +565,72 @@ def main(args: argparse.Namespace): "Please specify '--dataset-name' and the corresponding " "'--dataset-path' if required.") - elif args.dataset_name == "sharegpt": - input_requests = sample_sharegpt_requests( - dataset_path=args.dataset_path, - num_requests=args.num_prompts, - tokenizer=tokenizer, - fixed_output_len=args.sharegpt_output_len, - ) - - elif args.dataset_name == "burstgpt": - input_requests = sample_burstgpt_requests( - dataset_path=args.dataset_path, - num_requests=args.num_prompts, - random_seed=args.seed, - tokenizer=tokenizer, - ) - - elif args.dataset_name == "sonnet": - # Do not format the prompt, pass to message directly + if args.dataset_name == "sonnet": + dataset = SonnetDataset(dataset_path=args.dataset_path) + # For the "sonnet" dataset, formatting depends on the backend. if args.backend == "openai-chat": - input_requests = sample_sonnet_requests( - dataset_path=args.dataset_path, - num_requests=args.num_prompts, - input_len=args.sonnet_input_len, - output_len=args.sonnet_output_len, - prefix_len=args.sonnet_prefix_len, - tokenizer=tokenizer, - ) - input_requests = [(prompt, prompt_len, output_len, None) - for prompt, prompt_formatted, prompt_len, - output_len, _ in input_requests] + input_requests = dataset.sample(num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=False) else: - assert ( - tokenizer.chat_template or tokenizer.default_chat_template - ), "Tokenizer/model must have chat template for sonnet dataset." - input_requests = sample_sonnet_requests( - dataset_path=args.dataset_path, - num_requests=args.num_prompts, - input_len=args.sonnet_input_len, - output_len=args.sonnet_output_len, - prefix_len=args.sonnet_prefix_len, - tokenizer=tokenizer, - ) - input_requests = [(prompt_formatted, prompt_len, output_len, None) - for prompt, prompt_formatted, prompt_len, - output_len, _ in input_requests] + assert tokenizer.chat_template or tokenizer.default_chat_template, ( + "Tokenizer/model must have chat template for sonnet dataset.") + input_requests = dataset.sample(num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=True) elif args.dataset_name == "hf": - input_requests = sample_hf_requests( + # Choose between VisionArenaDataset + # and HuggingFaceDataset based on provided parameters. 
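+        # VisionArenaDataset is selected only when the path matches its fixed
+        # HF repo and no subset is given; every other HF dataset goes through
+        # the generic HuggingFaceDataset.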
+ dataset_class = (VisionArenaDataset if args.dataset_path + == VisionArenaDataset.VISION_ARENA_DATASET_PATH + and args.hf_subset is None else HuggingFaceDataset) + input_requests = dataset_class( dataset_path=args.dataset_path, dataset_subset=args.hf_subset, dataset_split=args.hf_split, + ).sample( num_requests=args.num_prompts, tokenizer=tokenizer, random_seed=args.seed, - fixed_output_len=args.hf_output_len, - ) - - elif args.dataset_name == "random": - input_requests = sample_random_requests( - prefix_len=args.random_prefix_len, - input_len=args.random_input_len, - output_len=args.random_output_len, - num_prompts=args.num_prompts, - range_ratio=args.random_range_ratio, - tokenizer=tokenizer, + output_len=args.hf_output_len, ) else: - raise ValueError(f"Unknown dataset: {args.dataset_name}") + # For datasets that follow a similar structure, use a mapping. + dataset_mapping = { + "sharegpt": + lambda: ShareGPTDataset(random_seed=args.seed, + dataset_path=args.dataset_path).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + ), + "burstgpt": + lambda: BurstGPTDataset(random_seed=args.seed, + dataset_path=args.dataset_path). + sample(tokenizer=tokenizer, num_requests=args.num_prompts), + "random": + lambda: RandomDataset(dataset_path=args.dataset_path).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.random_prefix_len, + input_len=args.random_input_len, + output_len=args.random_output_len, + range_ratio=args.random_range_ratio, + ) + } + try: + input_requests = dataset_mapping[args.dataset_name]() + except KeyError as err: + raise ValueError(f"Unknown dataset: {args.dataset_name}") from err goodput_config_dict = check_goodput_args(args) # Avoid GC processing "static" data - reduce pause times. @@ -1298,4 +992,5 @@ if __name__ == "__main__": "script chooses a LoRA module at random.") args = parser.parse_args() + main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 4ab82447..7e655673 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -6,13 +6,14 @@ import json import os import random import time -from functools import cache +import warnings from typing import Any, Optional, Union import torch import uvloop +from benchmark_dataset import (BurstGPTDataset, RandomDataset, SampleRequest, + ShareGPTDataset, SonnetDataset) from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json -from PIL import Image from tqdm import tqdm from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase) @@ -22,148 +23,10 @@ from vllm.entrypoints.openai.api_server import ( build_async_engine_client_from_engine_args) from vllm.inputs import TextPrompt, TokensPrompt from vllm.lora.request import LoRARequest -from vllm.lora.utils import get_adapter_absolute_path -from vllm.multimodal import MultiModalDataDict from vllm.sampling_params import BeamSearchParams -from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer from vllm.utils import FlexibleArgumentParser, merge_async_iterators -@dataclasses.dataclass -class SampleRequest: - """A class representing a single inference request for benchmarking. - - Attributes: - prompt: The input text prompt for the model. - prompt_len: The length of the prompt in tokens. - expected_output_len: The expected length of the output in tokens. - multi_modal_data: Optional dictionary containing multi-modal data (e.g. - images). 
- lora_request: Optional LoRARequest specifying the LoRA to use. - """ - prompt: str - prompt_len: int - expected_output_len: int - multi_modal_data: Optional[MultiModalDataDict] = None - lora_request: Optional[LoRARequest] = None - - -def _get_prompt_for_image_model(question: str, *, model: str) -> str: - """Prepend and append special tokens around the question to form a prompt. - - Args: - question: The input question text to wrap with special tokens - model: The name of the model being used, to determine which special - tokens to add - - Returns: - The formatted prompt string with appropriate special tokens for the - model - - Raises: - ValueError: If an unsupported model name is provided - """ - model = model.lower() - if "pixtral" in model: - return f"[INST]{question}\n[IMG][/INST]" - raise ValueError(f"Unsupported model {model}") - - -@cache -def lora_path_on_disk(lora_path: str) -> str: - return get_adapter_absolute_path(lora_path) - - -lora_tokenizer_cache: dict[int, AnyTokenizer] = {} - - -def get_random_lora_request( - args: argparse.Namespace -) -> tuple[LoRARequest, Optional[AnyTokenizer]]: - global lora_tokenizer_cache - lora_id = random.randint(1, args.max_loras) - lora_request = LoRARequest(lora_name=str(lora_id), - lora_int_id=lora_id, - lora_path=lora_path_on_disk(args.lora_path)) - if lora_id not in lora_tokenizer_cache: - lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request) - return lora_request, lora_tokenizer_cache[lora_id] - - -def sample_requests(tokenizer: PreTrainedTokenizerBase, - args: argparse.Namespace) -> list[SampleRequest]: - - dataset_path: str = args.dataset - num_requests: int = args.num_prompts - fixed_output_len: Optional[int] = args.output_len - model: str = args.model - if fixed_output_len is not None and fixed_output_len < 4: - raise ValueError("output_len too small") - - # Load the dataset. - with open(dataset_path) as f: - dataset = json.load(f) - # Filter out the conversations with less than 2 turns. - dataset = [data for data in dataset if len(data["conversations"]) >= 2] - # Shuffle the dataset. - random.shuffle(dataset) - - # Filter out sequences that are too long or too short - filtered_dataset: list[SampleRequest] = [] - for data in tqdm(dataset, - total=len(filtered_dataset), - desc="sampling requests"): - if len(filtered_dataset) == num_requests: - break - - # Only keep the first two turns of each conversation. - prompt = data["conversations"][0]["value"] - completion = data["conversations"][1]["value"] - - multi_modal_data: Optional[MultiModalDataDict] = None - if "image" in data: - multi_modal_data = multi_modal_data or {} - image_path = data["image"] - # TODO(vllm-project/vllm/issues/9778): Support multiple images. - assert isinstance(image_path, - str), "Only support single image input" - try: - multi_modal_data["image"] = Image.open(image_path).convert( - "RGB") - except FileNotFoundError: - # Ignore datapoint where asset is missing - continue - prompt = _get_prompt_for_image_model(question=prompt, model=model) - - request_tokenizer = tokenizer - lora_request: Optional[LoRARequest] = None - if args.enable_lora: - lora_request, lora_tokenizer = get_random_lora_request(args) - if lora_tokenizer: - request_tokenizer = lora_tokenizer - - # Tokenize the prompts and completions. 
- prompt_token_ids = request_tokenizer(prompt).input_ids - completion_token_ids = request_tokenizer(completion).input_ids - prompt_len = len(prompt_token_ids) - output_len = len(completion_token_ids - ) if fixed_output_len is None else fixed_output_len - if prompt_len < 4 or output_len < 4: - # Prune too short sequences. - continue - if prompt_len > 1024 or prompt_len + output_len > 2048: - # Prune too long sequences. - continue - filtered_dataset.append( - SampleRequest(prompt=prompt, - prompt_len=prompt_len, - expected_output_len=output_len, - multi_modal_data=multi_modal_data, - lora_request=lora_request)) - - return filtered_dataset - - def run_vllm( requests: list[SampleRequest], n: int, @@ -381,61 +244,50 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace, write_to_json(pt_file, pt_records) +def get_requests(args, tokenizer): + # Common parameters for all dataset types. + common_kwargs = { + "dataset_path": args.dataset_path, + "random_seed": args.seed, + } + sample_kwargs = { + "tokenizer": tokenizer, + "lora_path": args.lora_path, + "max_loras": args.max_loras, + "num_requests": args.num_prompts, + "input_len": args.input_len, + "output_len": args.output_len, + } + if args.dataset_path is None or args.dataset_name == "random": + sample_kwargs["range_ratio"] = args.random_range_ratio + sample_kwargs["prefix_len"] = args.prefix_len + dataset_cls = RandomDataset + elif args.dataset_name == "sharegpt": + dataset_cls = ShareGPTDataset + elif args.dataset_name == "sonnet": + assert tokenizer.chat_template or tokenizer.default_chat_template, ( + "Tokenizer/model must have chat template for sonnet dataset.") + dataset_cls = SonnetDataset + sample_kwargs["prefix_len"] = args.prefix_len + sample_kwargs["return_prompt_formatted"] = True + elif args.dataset_name == "burstgpt": + dataset_cls = BurstGPTDataset + else: + raise ValueError(f"Unknown dataset name: {args.dataset_name}") + # Remove None values + sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None} + return dataset_cls(**common_kwargs).sample(**sample_kwargs) + + def main(args: argparse.Namespace): + if args.seed is None: + args.seed = 0 print(args) random.seed(args.seed) - # Sample the requests. tokenizer = AutoTokenizer.from_pretrained( args.tokenizer, trust_remote_code=args.trust_remote_code) - if args.dataset is None: - vocab_size = tokenizer.vocab_size - requests = [] - for _ in range(args.num_prompts): - - request_tokenizer = tokenizer - lora_request: Optional[LoRARequest] = None - if args.enable_lora: - lora_request, lora_tokenizer = get_random_lora_request(args) - if lora_tokenizer: - request_tokenizer = lora_tokenizer - - # Synthesize a prompt with the given input length. - candidate_ids = [ - random.randint(0, vocab_size - 1) - for _ in range(args.input_len) - ] - - candidate_prompt = {"prompt_token_ids": candidate_ids} - - if not args.skip_tokenizer_init: - # As tokenizer may add additional tokens like BOS, we need - # to try different lengths to get the desired input length. 
- for _ in range(5): # Max attempts to correct - candidate_prompt = request_tokenizer.decode(candidate_ids) - tokenized_len = len( - request_tokenizer.encode(candidate_prompt)) - - if tokenized_len == args.input_len: - break - - # Adjust length based on difference - diff = args.input_len - tokenized_len - if diff > 0: - candidate_ids.extend([ - random.randint(100, vocab_size - 100) - for _ in range(diff) - ]) - else: - candidate_ids = candidate_ids[:diff] - requests.append( - SampleRequest(prompt=candidate_prompt, - prompt_len=args.input_len, - expected_output_len=args.output_len, - lora_request=lora_request)) - else: - requests = sample_requests(tokenizer, args) - + requests = get_requests(args, tokenizer) is_multi_modal = any(request.multi_modal_data is not None for request in requests) if args.backend == "vllm": @@ -470,7 +322,7 @@ def main(args: argparse.Namespace): print("\033[91mWARNING\033[0m: Multi-modal request detected. The " "following metrics are not accurate because image tokens are not" " counted. See vllm-project/vllm/issues/9778 for details.") - # TODO(vllm-project/vllm/issues/9778): Count molti-modal token length. + # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length. print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " f"{total_output_tokens / elapsed_time:.2f} output tokens/s") @@ -495,12 +347,23 @@ if __name__ == "__main__": type=str, choices=["vllm", "hf", "mii"], default="vllm") - parser.add_argument("--dataset", + parser.add_argument("--dataset-name", + type=str, + choices=["sharegpt", "random", "sonnet", "burstgpt"], + help="Name of the dataset to benchmark on.", + default="sharegpt") + parser.add_argument( + "--dataset", + type=str, + default=None, + help="Path to the ShareGPT dataset, will be deprecated in\ + the next release. The dataset is expected to " + "be a json in form of list[dict[..., conversations: " + "list[dict[..., value: ]]]]") + parser.add_argument("--dataset-path", type=str, default=None, - help="Path to the dataset. The dataset is expected to " - "be a json in form of list[dict[..., conversations: " - "list[dict[..., value: ]]]]") + help="Path to the dataset") parser.add_argument("--input-len", type=int, default=None, @@ -547,14 +410,35 @@ if __name__ == "__main__": default=None, help="Path to the lora adapters to use. This can be an absolute path, " "a relative path, or a Hugging Face model identifier.") + parser.add_argument("--prefix-len", + type=int, + default=None, + help="Number of prefix tokens per request." + "This is for the RandomDataset and SonnetDataset") + # random dataset + parser.add_argument( + "--random-range-ratio", + type=float, + default=1.0, + help="Range of sampled ratio of input/output length, " + "used only for RandomDataSet.", + ) parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model - if args.dataset is None: - assert args.input_len is not None - assert args.output_len is not None + if args.dataset is not None: + warnings.warn( + "The '--dataset' argument will be deprecated in the next " + "release. 
Please use '--dataset-name' and " "'--dataset-path' in future runs.", stacklevel=2) + args.dataset_path = args.dataset + if args.dataset is None and args.dataset_path is None: + # For the random dataset, the default sampling settings are defined in + # benchmark_dataset.RandomDataset. + print("When no dataset is set, the benchmark defaults to the random dataset.") else: assert args.input_len is None if args.enable_lora:
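As a brief, hedged illustration of the consolidated API that benchmark_dataset.py introduces, the sketch below samples requests directly from the new classes outside of the bundled benchmark scripts. The ShareGPT JSON path and the tokenizer name are placeholders rather than part of this patch; only the classes, constructor arguments, and sample() parameters defined above are used.

# Minimal usage sketch (illustrative only; file path and model name are placeholders).
from transformers import AutoTokenizer

from benchmark_dataset import RandomDataset, ShareGPTDataset

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# ShareGPT-style JSON file on disk (placeholder path).
sharegpt_requests = ShareGPTDataset(
    dataset_path="./sharegpt.json",
    random_seed=0,
).sample(tokenizer=tokenizer, num_requests=8)

# Synthetic prompts require no dataset file at all.
random_requests = RandomDataset(random_seed=0).sample(
    tokenizer=tokenizer, num_requests=8, input_len=256, output_len=64)

for req in sharegpt_requests + random_requests:
    print(req.prompt_len, req.expected_output_len)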