From b4fe16c75b437794900afcc3e1aa53df34e5ea38 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 15 Apr 2025 00:10:35 -0600 Subject: [PATCH] Add `vllm bench [latency, throughput]` CLI commands (#16508) Signed-off-by: mgoin --- .buildkite/test-pipeline.yaml | 7 + tests/benchmarks/__init__.py | 0 tests/benchmarks/test_latency_cli.py | 19 + tests/benchmarks/test_serve_cli.py | 44 + tests/benchmarks/test_throughput_cli.py | 19 + vllm/benchmarks/datasets.py | 831 +++++++++++++++++++ vllm/benchmarks/latency.py | 181 ++++ vllm/benchmarks/throughput.py | 608 ++++++++++++++ vllm/entrypoints/cli/benchmark/latency.py | 29 + vllm/entrypoints/cli/benchmark/main.py | 6 +- vllm/entrypoints/cli/benchmark/throughput.py | 29 + 11 files changed, 1771 insertions(+), 2 deletions(-) create mode 100644 tests/benchmarks/__init__.py create mode 100644 tests/benchmarks/test_latency_cli.py create mode 100644 tests/benchmarks/test_serve_cli.py create mode 100644 tests/benchmarks/test_throughput_cli.py create mode 100644 vllm/benchmarks/datasets.py create mode 100644 vllm/benchmarks/latency.py create mode 100644 vllm/benchmarks/throughput.py create mode 100644 vllm/entrypoints/cli/benchmark/latency.py create mode 100644 vllm/entrypoints/cli/benchmark/throughput.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 16acc2fd..38961138 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -341,6 +341,13 @@ steps: commands: - bash scripts/run-benchmarks.sh +- label: Benchmarks CLI Test # 10min + source_file_dependencies: + - vllm/ + - tests/benchmarks/ + commands: + - pytest -v -s benchmarks/ + - label: Quantization Test # 33min source_file_dependencies: - csrc/ diff --git a/tests/benchmarks/__init__.py b/tests/benchmarks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/benchmarks/test_latency_cli.py b/tests/benchmarks/test_latency_cli.py new file mode 100644 index 00000000..8537459b --- /dev/null +++ b/tests/benchmarks/test_latency_cli.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +import subprocess + +import pytest + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" + + +@pytest.mark.benchmark +def test_bench_latency(): + command = [ + "vllm", "bench", "latency", "--model", MODEL_NAME, "--input-len", "32", + "--output-len", "1", "--enforce-eager", "--load-format", "dummy" + ] + result = subprocess.run(command, capture_output=True, text=True) + print(result.stdout) + print(result.stderr) + + assert result.returncode == 0, f"Benchmark failed: {result.stderr}" diff --git a/tests/benchmarks/test_serve_cli.py b/tests/benchmarks/test_serve_cli.py new file mode 100644 index 00000000..b746d6b7 --- /dev/null +++ b/tests/benchmarks/test_serve_cli.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +import subprocess + +import pytest + +from ..utils import RemoteOpenAIServer + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy" + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.benchmark +def test_bench_serve(server): + command = [ + "vllm", + "bench", + "serve", + "--model", + MODEL_NAME, + "--host", + server.host, + "--port", + str(server.port), + "--random-input-len", + "32", + "--random-output-len", + "4", + "--num-prompts", + "5", + ] + result = subprocess.run(command, capture_output=True, text=True) + print(result.stdout) + 
print(result.stderr) + + assert result.returncode == 0, f"Benchmark failed: {result.stderr}" diff --git a/tests/benchmarks/test_throughput_cli.py b/tests/benchmarks/test_throughput_cli.py new file mode 100644 index 00000000..2045b362 --- /dev/null +++ b/tests/benchmarks/test_throughput_cli.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +import subprocess + +import pytest + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" + + +@pytest.mark.benchmark +def test_bench_throughput(): + command = [ + "vllm", "bench", "throughput", "--model", MODEL_NAME, "--input-len", + "32", "--output-len", "1", "--enforce-eager", "--load-format", "dummy" + ] + result = subprocess.run(command, capture_output=True, text=True) + print(result.stdout) + print(result.stderr) + + assert result.returncode == 0, f"Benchmark failed: {result.stderr}" diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py new file mode 100644 index 00000000..299c888c --- /dev/null +++ b/vllm/benchmarks/datasets.py @@ -0,0 +1,831 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This module defines a framework for sampling benchmark requests from various +datasets. Each dataset subclass of BenchmarkDataset must implement sample +generation. Supported dataset types include: + - ShareGPT + - Random (synthetic) + - Sonnet + - BurstGPT + - HuggingFace + - VisionArena + +TODO: Implement CustomDataset to parse a JSON file and convert its contents into +SampleRequest instances, similar to the approach used in ShareGPT. +""" + +import base64 +import io +import json +import logging +import random +from abc import ABC, abstractmethod +from collections.abc import Mapping +from dataclasses import dataclass +from functools import cache +from io import BytesIO +from typing import Any, Callable, Optional, Union + +import numpy as np +from PIL import Image +from transformers import PreTrainedTokenizerBase + +from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path +from vllm.multimodal import MultiModalDataDict +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer + +logger = logging.getLogger(__name__) + +# ----------------------------------------------------------------------------- +# Data Classes +# ----------------------------------------------------------------------------- + + +@dataclass +class SampleRequest: + """ + Represents a single inference request for benchmarking. + """ + + prompt: Union[str, Any] + prompt_len: int + expected_output_len: int + multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None + lora_request: Optional[LoRARequest] = None + + +# ----------------------------------------------------------------------------- +# Benchmark Dataset Base Class +# ----------------------------------------------------------------------------- + + +class BenchmarkDataset(ABC): + DEFAULT_SEED = 0 + + def __init__( + self, + dataset_path: Optional[str] = None, + random_seed: int = DEFAULT_SEED, + ) -> None: + """ + Initialize the BenchmarkDataset with an optional dataset path and random + seed. + + Args: + dataset_path (Optional[str]): Path to the dataset. If None, it + indicates that a default or random dataset might be used. + random_seed (int): Seed value for reproducible shuffling or + sampling. Defaults to DEFAULT_SEED. + """ + self.dataset_path = dataset_path + # Set the random seed, ensuring that a None value is replaced with the + # default seed. 
+ self.random_seed = (random_seed + if random_seed is not None else self.DEFAULT_SEED) + self.data = None + + def apply_multimodal_chat_transformation( + self, + prompt: str, + mm_content: Optional[MultiModalDataDict] = None) -> list[dict]: + """ + Transform a prompt and optional multimodal content into a chat format. + This method is used for chat models that expect a specific conversation + format. + """ + content = [{"text": prompt, "type": "text"}] + if mm_content is not None: + content.append(mm_content) + return [{"role": "user", "content": content}] + + def load_data(self) -> None: + """ + Load data from the dataset path into self.data. + + This method must be overridden by subclasses since the method to load + data will vary depending on the dataset format and source. + + Raises: + NotImplementedError: If a subclass does not implement this method. + """ + # TODO (jenniferzhao): add support for downloading data + raise NotImplementedError( + "load_data must be implemented in subclasses.") + + def get_random_lora_request( + self, + tokenizer: PreTrainedTokenizerBase, + max_loras: Optional[int] = None, + lora_path: Optional[str] = None, + ) -> tuple[Optional[LoRARequest], AnyTokenizer]: + """ + Optionally select a random LoRA request and return its associated + tokenizer. + + This method is used when LoRA parameters are provided. It randomly + selects a LoRA based on max_loras and retrieves a cached tokenizer for + that LoRA if available. Otherwise, it returns the base tokenizer. + + Args: + tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no + LoRA is selected. max_loras (Optional[int]): The maximum number of + LoRAs available. If None, LoRA is not used. lora_path + (Optional[str]): Path to the LoRA parameters on disk. If None, LoRA + is not used. + + Returns: + tuple[Optional[LoRARequest], AnyTokenizer]: A tuple where the first + element is a LoRARequest (or None if not applicable) and the second + element is the tokenizer associated with the LoRA request (or the + base tokenizer). + """ + if max_loras is None or lora_path is None: + return None, tokenizer + + # Generate a random LoRA ID in the range [1, max_loras]. + lora_id = random.randint(1, max_loras) + lora_request = LoRARequest( + lora_name=str(lora_id), + lora_int_id=lora_id, + lora_path=lora_path_on_disk(lora_path), + ) + if lora_id not in lora_tokenizer_cache: + lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request) + # Return lora_request and the cached tokenizer if available; otherwise, + # return the base tokenizer + return lora_request, lora_tokenizer_cache[lora_id] or tokenizer + + @abstractmethod + def sample(self, tokenizer: PreTrainedTokenizerBase, + num_requests: int) -> list[SampleRequest]: + """ + Abstract method to generate sample requests from the dataset. + + Subclasses must override this method to implement dataset-specific logic + for generating a list of SampleRequest objects. + + Args: + tokenizer (PreTrainedTokenizerBase): The tokenizer to be used + for processing the dataset's text. + num_requests (int): The number of sample requests to generate. + + Returns: + list[SampleRequest]: A list of sample requests generated from the + dataset. + """ + raise NotImplementedError("sample must be implemented in subclasses.") + + def maybe_oversample_requests(self, requests: list[SampleRequest], + num_requests: int) -> None: + """ + Oversamples the list of requests if its size is less than the desired + number. 
+ + Args: + requests (List[SampleRequest]): The current list of sampled + requests. num_requests (int): The target number of requests. + """ + if len(requests) < num_requests: + random.seed(self.random_seed) + additional = random.choices(requests, + k=num_requests - len(requests)) + requests.extend(additional) + logger.info("Oversampled requests to reach %d total samples.", + num_requests) + + +# ----------------------------------------------------------------------------- +# Utility Functions and Global Caches +# ----------------------------------------------------------------------------- + + +def is_valid_sequence( + prompt_len: int, + output_len: int, + min_len: int = 4, + max_prompt_len: int = 1024, + max_total_len: int = 2048, + skip_min_output_len_check: bool = False, +) -> bool: + """ + Validate a sequence based on prompt and output lengths. + + Default pruning criteria are copied from the original `sample_hf_requests` + and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as + from `sample_requests` in benchmark_throughput.py. + """ + # Check for invalid conditions + prompt_too_short = prompt_len < min_len + output_too_short = (not skip_min_output_len_check) and (output_len + < min_len) + prompt_too_long = prompt_len > max_prompt_len + combined_too_long = (prompt_len + output_len) > max_total_len + + # Return True if none of the invalid conditions are met + return not (prompt_too_short or output_too_short or prompt_too_long + or combined_too_long) + + +@cache +def lora_path_on_disk(lora_path: str) -> str: + return get_adapter_absolute_path(lora_path) + + +# Global cache for LoRA tokenizers. +lora_tokenizer_cache: dict[int, AnyTokenizer] = {} + + +def process_image(image: Any) -> Mapping[str, Any]: + """ + Process a single image input and return a multimedia content dictionary. + + Supports three input types: + + 1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key + containing raw image data. - Loads the bytes as a PIL.Image.Image. + + 2. PIL.Image.Image input: - Converts the image to RGB. - Saves the image as + a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns + a dictionary with the image as a base64 data URL. + + 3. String input: - Treats the string as a URL or local file path. - + Prepends "file://" if the string doesn't start with "http://" or + "file://". - Returns a dictionary with the image URL. + + Raises: + ValueError: If the input is not a supported type. + """ + if isinstance(image, dict) and 'bytes' in image: + image = Image.open(BytesIO(image['bytes'])) + if isinstance(image, Image.Image): + image = image.convert("RGB") + with io.BytesIO() as image_data: + image.save(image_data, format="JPEG") + image_base64 = base64.b64encode( + image_data.getvalue()).decode("utf-8") + return { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + } + + if isinstance(image, str): + image_url = (image if image.startswith( + ("http://", "file://")) else f"file://{image}") + return {"type": "image_url", "image_url": {"url": image_url}} + + raise ValueError(f"Invalid image input {image}. 
Must be a PIL.Image.Image" + " or str or dictionary with raw image bytes.") + + +# ----------------------------------------------------------------------------- +# Random Dataset Implementation (Synthetic Data) +# ----------------------------------------------------------------------------- + + +class RandomDataset(BenchmarkDataset): + # Default values copied from benchmark_serving.py for the random dataset. + DEFAULT_PREFIX_LEN = 0 + DEFAULT_RANGE_RATIO = 0.0 + DEFAULT_INPUT_LEN = 1024 + DEFAULT_OUTPUT_LEN = 128 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + range_ratio: float = DEFAULT_RANGE_RATIO, + input_len: int = DEFAULT_INPUT_LEN, + output_len: int = DEFAULT_OUTPUT_LEN, + **kwargs, + ) -> list[SampleRequest]: + # Enforce range_ratio < 1 + assert range_ratio < 1.0, ( + "random_range_ratio must be < 1.0 to ensure a valid sampling range" + ) + + vocab_size = tokenizer.vocab_size + + prefix_token_ids = (np.random.randint( + 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) + + # New sampling logic: [X * (1 - b), X * (1 + b)] + input_low = int(input_len * (1 - range_ratio)) + input_high = int(input_len * (1 + range_ratio)) + output_low = int(output_len * (1 - range_ratio)) + output_high = int(output_len * (1 + range_ratio)) + + # Add logging for debugging + logger.info("Sampling input_len from [%s, %s]", input_low, input_high) + logger.info("Sampling output_len from [%s, %s]", output_low, + output_high) + + input_lens = np.random.randint(input_low, + input_high + 1, + size=num_requests) + output_lens = np.random.randint(output_low, + output_high + 1, + size=num_requests) + offsets = np.random.randint(0, vocab_size, size=num_requests) + + requests = [] + for i in range(num_requests): + inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % + vocab_size).tolist() + token_sequence = prefix_token_ids + inner_seq + prompt = tokenizer.decode(token_sequence) + total_input_len = prefix_len + int(input_lens[i]) + requests.append( + SampleRequest( + prompt=prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + )) + return requests + + +# ----------------------------------------------------------------------------- +# ShareGPT Dataset Implementation +# ----------------------------------------------------------------------------- + + +class ShareGPTDataset(BenchmarkDataset): + """ + Implements the ShareGPT dataset. Loads data from a JSON file and generates + sample requests based on conversation turns. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + with open(self.dataset_path, encoding="utf-8") as f: + self.data = json.load(f) + # Filter entries with at least two conversation turns. 
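+ # sample() treats the first turn as the prompt and the second as the reference completion, so both turns must be present.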
+ self.data = [ + entry for entry in self.data + if "conversations" in entry and len(entry["conversations"]) >= 2 + ] + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + samples: list = [] + for entry in self.data: + if len(samples) >= num_requests: + break + prompt, completion = ( + entry["conversations"][0]["value"], + entry["conversations"][1]["value"], + ) + + lora_request, tokenizer = self.get_random_lora_request( + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + new_output_len = (len(completion_ids) + if output_len is None else output_len) + if not is_valid_sequence(prompt_len, + new_output_len, + skip_min_output_len_check=output_len + is not None): + continue + if enable_multimodal_chat: + prompt = self.apply_multimodal_chat_transformation( + prompt, None) + samples.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=new_output_len, + lora_request=lora_request, + )) + self.maybe_oversample_requests(samples, num_requests) + return samples + + +# ----------------------------------------------------------------------------- +# Sonnet Dataset Implementation +# ----------------------------------------------------------------------------- + + +class SonnetDataset(BenchmarkDataset): + """ + Simplified implementation of the Sonnet dataset. Loads poem lines from a + text file and generates sample requests. Default values here copied from + `benchmark_serving.py` for the sonnet dataset. + """ + + DEFAULT_PREFIX_LEN = 200 + DEFAULT_INPUT_LEN = 550 + DEFAULT_OUTPUT_LEN = 150 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if not self.dataset_path: + raise ValueError("dataset_path must be provided.") + with open(self.dataset_path, encoding="utf-8") as f: + self.data = f.readlines() + + def sample( + self, + tokenizer, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + input_len: int = DEFAULT_INPUT_LEN, + output_len: int = DEFAULT_OUTPUT_LEN, + return_prompt_formatted: bool = False, + **kwargs, + ) -> list: + # Calculate average token length for a poem line. + tokenized_lines = [tokenizer(line).input_ids for line in self.data] + avg_len = sum(len(tokens) + for tokens in tokenized_lines) / len(tokenized_lines) + + # Build the base prompt. + base_prompt = "Pick as many lines as you can from these poem lines:\n" + base_msg = [{"role": "user", "content": base_prompt}] + base_fmt = tokenizer.apply_chat_template(base_msg, + add_generation_prompt=True, + tokenize=False) + base_offset = len(tokenizer(base_fmt).input_ids) + if input_len <= base_offset: + raise ValueError( + f"'input_len' must be higher than the base prompt length " + f"({base_offset}).") + + # Determine how many poem lines to use. 
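+ # Both counts are derived from the average tokens-per-line estimate above, after subtracting the chat-template overhead (base_offset).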
+ num_input_lines = round((input_len - base_offset) / avg_len) + num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0) + prefix_lines = self.data[:num_prefix_lines] + + samples = [] + while len(samples) < num_requests: + extra_lines = random.choices(self.data, + k=num_input_lines - num_prefix_lines) + prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}" + msg = [{"role": "user", "content": prompt}] + prompt_formatted = tokenizer.apply_chat_template( + msg, add_generation_prompt=True, tokenize=False) + prompt_len = len(tokenizer(prompt_formatted).input_ids) + if prompt_len <= input_len: + samples.append( + SampleRequest( + prompt=prompt_formatted + if return_prompt_formatted else prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + return samples + + +# ----------------------------------------------------------------------------- +# BurstGPT Dataset Implementation +# ----------------------------------------------------------------------------- + + +class BurstGPTDataset(BenchmarkDataset): + """ + Implements the BurstGPT dataset. Loads data from a CSV file and generates + sample requests based on synthetic prompt generation. Only rows with Model + "GPT-4" and positive response tokens are used. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self, ): + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + try: + import pandas as pd + except ImportError as e: + raise ImportError( + "Pandas is required for BurstGPTDataset. Please install it " + "using `pip install pandas`.") from e + + df = pd.read_csv(self.dataset_path) + # Filter to keep only GPT-4 rows. + gpt4_df = df[df["Model"] == "GPT-4"] + # Remove failed requests (where Response tokens is 0 or less). + gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0] + # Sample the desired number of rows. + self.data = gpt4_df + + def _sample_loaded_data(self, num_requests: int) -> list: + if num_requests <= len(self.data): + data = self.data.sample(n=num_requests, + random_state=self.random_seed) + else: + data = self.data.sample( + n=num_requests, + random_state=self.random_seed, + replace=True, + ) + # Convert the dataframe to a list of lists. + return data.values.tolist() + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + max_loras: Optional[int] = None, + lora_path: Optional[str] = None, + **kwargs, + ) -> list[SampleRequest]: + samples = [] + data = self._sample_loaded_data(num_requests=num_requests) + for i in range(num_requests): + input_len = int(data[i][2]) + output_len = int(data[i][3]) + lora_req, tokenizer = self.get_random_lora_request( + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + vocab_size = tokenizer.vocab_size + # Generate a synthetic prompt: a list of token IDs computed as (i + + # j) modulo vocab_size. 
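+ # Decoding and re-tokenizing this prompt may not round-trip to exactly input_len tokens for every tokenizer, so prompt_len below records the intended length.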
+ token_ids = [(i + j) % vocab_size for j in range(input_len)] + prompt = tokenizer.decode(token_ids) + samples.append( + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=output_len, + lora_request=lora_req, + )) + return samples + + +# ----------------------------------------------------------------------------- +# HuggingFace Dataset Base Implementation +# ----------------------------------------------------------------------------- +class HuggingFaceDataset(BenchmarkDataset): + """Base class for datasets hosted on HuggingFace.""" + + SUPPORTED_DATASET_PATHS: Union[set[str], dict[str, Callable]] = set() + + def __init__( + self, + dataset_path: str, + dataset_split: str, + dataset_subset: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(dataset_path=dataset_path, **kwargs) + + self.dataset_split = dataset_split + self.dataset_subset = dataset_subset + self.load_data() + + def load_data(self) -> None: + """Load data from HuggingFace datasets.""" + try: + from datasets import load_dataset + except ImportError as e: + raise ImportError( + "Hugging Face datasets library is required for this dataset. " + "Please install it using `pip install datasets`.") from e + + self.data = load_dataset( + self.dataset_path, + name=self.dataset_subset, + split=self.dataset_split, + streaming=True, + ) + self.data = self.data.shuffle(seed=self.random_seed) + + +# ----------------------------------------------------------------------------- +# Conversation Dataset Implementation +# ----------------------------------------------------------------------------- + + +class ConversationDataset(HuggingFaceDataset): + """Dataset for conversation data with multimodal support.""" + SUPPORTED_DATASET_PATHS = { + 'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered' + } + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs) -> list: + # Filter examples with at least 2 conversations + filtered_data = self.data.filter( + lambda x: len(x["conversations"]) >= 2) + sampled_requests = [] + dynamic_output = output_len is None + + for item in filtered_data: + if len(sampled_requests) >= num_requests: + break + conv = item["conversations"] + prompt, completion = conv[0]["value"], conv[1]["value"] + + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + completion_len = len(completion_ids) + output_len = completion_len if dynamic_output else output_len + assert isinstance(output_len, int) and output_len > 0 + if dynamic_output and not is_valid_sequence( + prompt_len, completion_len): + continue + mm_content = process_image( + item["image"]) if "image" in item else None + if enable_multimodal_chat: + # Note: when chat is enabled the request prompt_len is no longer + # accurate and we will be using request output to count the + # actual prompt len and output len + prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Vision Arena Dataset Implementation +# ----------------------------------------------------------------------------- + + +class 
VisionArenaDataset(HuggingFaceDataset): + """ + Vision Arena Dataset. + """ + + DEFAULT_OUTPUT_LEN = 128 + SUPPORTED_DATASET_PATHS = { + "lmarena-ai/VisionArena-Chat": + lambda x: x["conversation"][0][0]["content"], + "lmarena-ai/vision-arena-bench-v0.1": + lambda x: x["turns"][0][0]["content"] + } + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) + if parser_fn is None: + raise ValueError( + f"Unsupported dataset path: {self.dataset_path}") + prompt = parser_fn(item) + mm_content = process_image(item["images"][0]) + prompt_len = len(tokenizer(prompt).input_ids) + if enable_multimodal_chat: + # Note: when chat is enabled the request prompt_len is no longer + # accurate and we will be using request output to count the + # actual prompt len + prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Instruct Coder Dataset Implementation +# ----------------------------------------------------------------------------- + + +class InstructCoderDataset(HuggingFaceDataset): + """ + InstructCoder Dataset. + https://huggingface.co/datasets/likaixin/InstructCoder + + InstructCoder is the dataset designed for general code editing. It consists + of 114,239 instruction-input-output triplets, and covers multiple distinct + code editing scenario. + """ + + DEFAULT_OUTPUT_LEN = 200 # this is the average default output length + SUPPORTED_DATASET_PATHS = { + "likaixin/InstructCoder", + } + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = f"{item['instruction']}:\n{item['input']}" + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# AIMO Dataset Implementation +# ----------------------------------------------------------------------------- + + +class AIMODataset(HuggingFaceDataset): + """ + Dataset class for processing a AIMO dataset with reasoning questions. 
+ """ + SUPPORTED_DATASET_PATHS = { + "AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5", + "AI-MO/NuminaMath-CoT" + } + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + **kwargs) -> list: + sampled_requests = [] + dynamic_output = output_len is None + + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt, completion = item['problem'], item["solution"] + + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + completion_len = len(completion_ids) + output_len = completion_len if dynamic_output else output_len + assert isinstance(output_len, int) and output_len > 0 + if dynamic_output and not is_valid_sequence(prompt_len, + completion_len, + max_prompt_len=2048, + max_total_len=32000): + continue + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=None, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests diff --git a/vllm/benchmarks/latency.py b/vllm/benchmarks/latency.py new file mode 100644 index 00000000..06f6848f --- /dev/null +++ b/vllm/benchmarks/latency.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Benchmark the latency of processing a single batch of requests.""" + +import argparse +import dataclasses +import json +import os +import time +from pathlib import Path +from typing import Any, Optional + +import numpy as np +import torch +from tqdm import tqdm + +from vllm import LLM, SamplingParams +from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format, + write_to_json) +from vllm.engine.arg_utils import EngineArgs +from vllm.inputs import PromptType +from vllm.sampling_params import BeamSearchParams + + +def save_to_pytorch_benchmark_format(args: argparse.Namespace, + results: dict[str, Any]) -> None: + pt_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={"latency": results["latencies"]}, + extra_info={k: results[k] + for k in ["avg_latency", "percentiles"]}) + if pt_records: + pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" + write_to_json(pt_file, pt_records) + + +def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument("--input-len", type=int, default=32) + parser.add_argument("--output-len", type=int, default=128) + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument( + "--n", + type=int, + default=1, + help="Number of generated sequences per prompt.", + ) + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument( + "--num-iters-warmup", + type=int, + default=10, + help="Number of iterations to run for warmup.", + ) + parser.add_argument("--num-iters", + type=int, + default=30, + help="Number of iterations to run.") + parser.add_argument( + "--profile", + action="store_true", + help="profile the generation process of a single batch", + ) + parser.add_argument( + "--profile-result-dir", + type=str, + default=None, + help=("path to save the pytorch profiler output. Can be visualized " + "with ui.perfetto.dev or Tensorboard."), + ) + parser.add_argument( + "--output-json", + type=str, + default=None, + help="Path to save the latency results in JSON format.", + ) + parser.add_argument( + "--disable-detokenize", + action="store_true", + help=("Do not detokenize responses (i.e. 
do not include " + "detokenization time in the latency measurement)"), + ) + + parser = EngineArgs.add_cli_args(parser) + + +def main(args: argparse.Namespace): + print(args) + + engine_args = EngineArgs.from_cli_args(args) + + # NOTE(woosuk): If the request cannot be processed in a single batch, + # the engine will automatically process the request in multiple batches. + llm = LLM(**dataclasses.asdict(engine_args)) + assert llm.llm_engine.model_config.max_model_len >= ( + args.input_len + + args.output_len), ("Please ensure that max_model_len is greater than" + " the sum of input_len and output_len.") + + sampling_params = SamplingParams( + n=args.n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=args.output_len, + detokenize=not args.disable_detokenize, + ) + print(sampling_params) + dummy_prompt_token_ids = np.random.randint(10000, + size=(args.batch_size, + args.input_len)) + dummy_prompts: list[PromptType] = [{ + "prompt_token_ids": batch + } for batch in dummy_prompt_token_ids.tolist()] + + def llm_generate(): + if not args.use_beam_search: + llm.generate(dummy_prompts, + sampling_params=sampling_params, + use_tqdm=False) + else: + llm.beam_search( + dummy_prompts, + BeamSearchParams( + beam_width=args.n, + max_tokens=args.output_len, + ignore_eos=True, + ), + ) + + def run_to_completion(profile_dir: Optional[str] = None): + if profile_dir: + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + on_trace_ready=torch.profiler.tensorboard_trace_handler( + str(profile_dir)), + ) as p: + llm_generate() + print(p.key_averages().table(sort_by="self_cuda_time_total")) + else: + start_time = time.perf_counter() + llm_generate() + end_time = time.perf_counter() + latency = end_time - start_time + return latency + + print("Warming up...") + for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"): + run_to_completion(profile_dir=None) + + if args.profile: + profile_dir = args.profile_result_dir + if not profile_dir: + profile_dir = (Path(".") / "vllm_benchmark_result" / + f"latency_result_{time.time()}") + print(f"Profiling (results will be saved to '{profile_dir}')...") + run_to_completion(profile_dir=profile_dir) + return + + # Benchmark. 
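+ # Time num_iters un-profiled runs and report the mean plus the 10/25/50/75/90/99th percentile latencies.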
+ latencies = [] + for _ in tqdm(range(args.num_iters), desc="Profiling iterations"): + latencies.append(run_to_completion(profile_dir=None)) + latencies = np.array(latencies) + percentages = [10, 25, 50, 75, 90, 99] + percentiles = np.percentile(latencies, percentages) + print(f"Avg latency: {np.mean(latencies)} seconds") + for percentage, percentile in zip(percentages, percentiles): + print(f"{percentage}% percentile latency: {percentile} seconds") + + # Output JSON results if specified + if args.output_json: + results = { + "avg_latency": np.mean(latencies), + "latencies": latencies.tolist(), + "percentiles": dict(zip(percentages, percentiles.tolist())), + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + save_to_pytorch_benchmark_format(args, results) diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py new file mode 100644 index 00000000..b3e24911 --- /dev/null +++ b/vllm/benchmarks/throughput.py @@ -0,0 +1,608 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Benchmark offline inference throughput.""" +import argparse +import dataclasses +import json +import os +import random +import time +import warnings +from typing import Any, Optional, Union + +import torch +import uvloop +from tqdm import tqdm +from transformers import (AutoModelForCausalLM, AutoTokenizer, + PreTrainedTokenizerBase) + +from vllm.benchmarks.datasets import (AIMODataset, BurstGPTDataset, + ConversationDataset, + InstructCoderDataset, RandomDataset, + SampleRequest, ShareGPTDataset, + SonnetDataset, VisionArenaDataset) +from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format, + write_to_json) +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.entrypoints.openai.api_server import ( + build_async_engine_client_from_engine_args) +from vllm.inputs import TextPrompt, TokensPrompt +from vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput +from vllm.sampling_params import BeamSearchParams +from vllm.utils import merge_async_iterators + + +def run_vllm( + requests: list[SampleRequest], + n: int, + engine_args: EngineArgs, + disable_detokenize: bool = False, +) -> tuple[float, Optional[list[RequestOutput]]]: + from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) + assert all( + llm.llm_engine.model_config.max_model_len >= ( + request.prompt_len + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests.") + # Add the requests to the engine. 
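+ # Token-id prompts are wrapped in TokensPrompt and plain-text prompts in TextPrompt; one SamplingParams is built per request.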
+ prompts: list[Union[TextPrompt, TokensPrompt]] = [] + sampling_params: list[SamplingParams] = [] + for request in requests: + prompts.append( + TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data) + if "prompt_token_ids" in request.prompt else \ + TextPrompt(prompt=request.prompt, + multi_modal_data=request.multi_modal_data)) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + detokenize=not disable_detokenize, + )) + lora_requests: Optional[list[LoRARequest]] = None + if engine_args.enable_lora: + lora_requests = [request.lora_request for request in requests] + + use_beam_search = False + + outputs = None + if not use_beam_search: + start = time.perf_counter() + outputs = llm.generate(prompts, + sampling_params, + lora_request=lora_requests, + use_tqdm=True) + end = time.perf_counter() + else: + assert lora_requests is None, "BeamSearch API does not support LoRA" + prompts = [request.prompt for request in requests] + # output_len should be the same for all requests. + output_len = requests[0][2] + for request in requests: + assert request.expected_output_len == output_len + start = time.perf_counter() + llm.beam_search( + prompts, + BeamSearchParams( + beam_width=n, + max_tokens=output_len, + ignore_eos=True, + )) + end = time.perf_counter() + return end - start, outputs + + +def run_vllm_chat( + requests: list[SampleRequest], + n: int, + engine_args: EngineArgs, + disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]: + """ + Run vLLM chat benchmark. This function is recommended ONLY for benchmarking + multimodal models as it properly handles multimodal inputs and chat + formatting. For non-multimodal models, use run_vllm() instead. + """ + from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) + + assert all( + llm.llm_engine.model_config.max_model_len >= ( + request.prompt_len + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of " + "prompt_len and expected_output_len for all requests.") + + prompts = [] + sampling_params: list[SamplingParams] = [] + for request in requests: + prompts.append(request.prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + detokenize=not disable_detokenize, + )) + start = time.perf_counter() + outputs = llm.chat(prompts, sampling_params, use_tqdm=True) + end = time.perf_counter() + return end - start, outputs + + +async def run_vllm_async( + requests: list[SampleRequest], + n: int, + engine_args: AsyncEngineArgs, + disable_frontend_multiprocessing: bool = False, + disable_detokenize: bool = False, +) -> float: + from vllm import SamplingParams + + async with build_async_engine_client_from_engine_args( + engine_args, disable_frontend_multiprocessing) as llm: + assert all( + llm.model_config.max_model_len >= (request.prompt_len + + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests.") + + # Add the requests to the engine. 
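+ # Same prompt and sampling-params construction as run_vllm, plus a per-request list of LoRA requests passed to generate().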
+ prompts: list[Union[TextPrompt, TokensPrompt]] = [] + sampling_params: list[SamplingParams] = [] + lora_requests: list[Optional[LoRARequest]] = [] + for request in requests: + prompts.append( + TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data) + if "prompt_token_ids" in request.prompt else \ + TextPrompt(prompt=request.prompt, + multi_modal_data=request.multi_modal_data)) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + detokenize=not disable_detokenize, + )) + lora_requests.append(request.lora_request) + + generators = [] + start = time.perf_counter() + for i, (prompt, sp, + lr) in enumerate(zip(prompts, sampling_params, lora_requests)): + generator = llm.generate(prompt, + sp, + lora_request=lr, + request_id=f"test{i}") + generators.append(generator) + all_gens = merge_async_iterators(*generators) + async for i, res in all_gens: + pass + end = time.perf_counter() + return end - start + + +def run_hf( + requests: list[SampleRequest], + model: str, + tokenizer: PreTrainedTokenizerBase, + n: int, + max_batch_size: int, + trust_remote_code: bool, + disable_detokenize: bool = False, +) -> float: + llm = AutoModelForCausalLM.from_pretrained( + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + if llm.config.model_type == "llama": + # To enable padding in the HF backend. + tokenizer.pad_token = tokenizer.eos_token + llm = llm.cuda() + + pbar = tqdm(total=len(requests)) + start = time.perf_counter() + batch: list[str] = [] + max_prompt_len = 0 + max_output_len = 0 + for i in range(len(requests)): + prompt = requests[i].prompt + prompt_len = requests[i].prompt_len + output_len = requests[i].expected_output_len + # Add the prompt to the batch. + batch.append(prompt) + max_prompt_len = max(max_prompt_len, prompt_len) + max_output_len = max(max_output_len, output_len) + if len(batch) < max_batch_size and i != len(requests) - 1: + # Check if we can add more requests to the batch. + next_prompt_len = requests[i + 1].prompt_len + next_output_len = requests[i + 1].expected_output_len + if (max(max_prompt_len, next_prompt_len) + + max(max_output_len, next_output_len)) <= 2048: + # We can add more requests to the batch. + continue + + # Generate the sequences. + input_ids = tokenizer(batch, return_tensors="pt", + padding=True).input_ids + llm_outputs = llm.generate( + input_ids=input_ids.cuda(), + do_sample=True, + num_return_sequences=n, + temperature=1.0, + top_p=1.0, + use_cache=True, + max_new_tokens=max_output_len, + ) + if not disable_detokenize: + # Include the decoding time. + tokenizer.batch_decode(llm_outputs, skip_special_tokens=True) + pbar.update(len(batch)) + + # Clear the batch. 
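+ # Also reset the running max prompt/output lengths used by the 2048-token batch-packing check above.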
+ batch = [] + max_prompt_len = 0 + max_output_len = 0 + end = time.perf_counter() + return end - start + + +def save_to_pytorch_benchmark_format(args: argparse.Namespace, + results: dict[str, Any]) -> None: + pt_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={ + "requests_per_second": [results["requests_per_second"]], + "tokens_per_second": [results["tokens_per_second"]], + }, + extra_info={ + k: results[k] + for k in ["elapsed_time", "num_requests", "total_num_tokens"] + }) + if pt_records: + # Don't use json suffix here as we don't want CI to pick it up + pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" + write_to_json(pt_file, pt_records) + + +def get_requests(args, tokenizer): + # Common parameters for all dataset types. + common_kwargs = { + "dataset_path": args.dataset_path, + "random_seed": args.seed, + } + sample_kwargs = { + "tokenizer": tokenizer, + "lora_path": args.lora_path, + "max_loras": args.max_loras, + "num_requests": args.num_prompts, + "input_len": args.input_len, + "output_len": args.output_len, + } + + if args.dataset_path is None or args.dataset_name == "random": + sample_kwargs["range_ratio"] = args.random_range_ratio + sample_kwargs["prefix_len"] = args.prefix_len + dataset_cls = RandomDataset + elif args.dataset_name == "sharegpt": + dataset_cls = ShareGPTDataset + if args.backend == "vllm-chat": + sample_kwargs["enable_multimodal_chat"] = True + elif args.dataset_name == "sonnet": + assert tokenizer.chat_template or tokenizer.default_chat_template, ( + "Tokenizer/model must have chat template for sonnet dataset.") + dataset_cls = SonnetDataset + sample_kwargs["prefix_len"] = args.prefix_len + sample_kwargs["return_prompt_formatted"] = True + elif args.dataset_name == "burstgpt": + dataset_cls = BurstGPTDataset + elif args.dataset_name == "hf": + if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: + dataset_cls = VisionArenaDataset + common_kwargs['dataset_subset'] = None + common_kwargs['dataset_split'] = "train" + sample_kwargs["enable_multimodal_chat"] = True + elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: + dataset_cls = InstructCoderDataset + common_kwargs['dataset_split'] = "train" + elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: + dataset_cls = ConversationDataset + common_kwargs['dataset_subset'] = args.hf_subset + common_kwargs['dataset_split'] = args.hf_split + sample_kwargs["enable_multimodal_chat"] = True + elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: + dataset_cls = AIMODataset + common_kwargs['dataset_subset'] = None + common_kwargs['dataset_split'] = "train" + else: + raise ValueError(f"Unknown dataset name: {args.dataset_name}") + # Remove None values + sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None} + return dataset_cls(**common_kwargs).sample(**sample_kwargs) + + +def validate_args(args): + """ + Validate command-line arguments. + """ + + # === Deprecation and Defaulting === + if args.dataset is not None: + warnings.warn( + "The '--dataset' argument will be deprecated in the next release. 
" + "Please use '--dataset-name' and '--dataset-path' instead.", + stacklevel=2) + args.dataset_path = args.dataset + + if not getattr(args, "tokenizer", None): + args.tokenizer = args.model + + # === Backend Validation === + valid_backends = {"vllm", "hf", "mii", "vllm-chat"} + if args.backend not in valid_backends: + raise ValueError(f"Unsupported backend: {args.backend}") + + # === Dataset Configuration === + if not args.dataset and not args.dataset_path: + print( + "When dataset path is not set, it will default to random dataset") + args.dataset_name = 'random' + if args.input_len is None: + raise ValueError("input_len must be provided for a random dataset") + + # === Dataset Name Specific Checks === + # --hf-subset and --hf-split: only used + # when dataset_name is 'hf' + if args.dataset_name != "hf" and ( + getattr(args, "hf_subset", None) is not None + or getattr(args, "hf_split", None) is not None): + warnings.warn("--hf-subset and --hf-split will be ignored \ + since --dataset-name is not 'hf'.", + stacklevel=2) + elif args.dataset_name == "hf": + if args.dataset_path in ( + VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() + | ConversationDataset.SUPPORTED_DATASET_PATHS): + assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501 + elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS + | AIMODataset.SUPPORTED_DATASET_PATHS): + assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501 + else: + raise ValueError( + f"{args.dataset_path} is not supported by hf dataset.") + + # --random-range-ratio: only used when dataset_name is 'random' + if args.dataset_name != 'random' and args.random_range_ratio is not None: + warnings.warn("--random-range-ratio will be ignored since \ + --dataset-name is not 'random'.", + stacklevel=2) + + # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not + # set. 
+ if args.dataset_name not in {"random", "sonnet", None + } and args.prefix_len is not None: + warnings.warn("--prefix-len will be ignored since --dataset-name\ + is not 'random', 'sonnet', or not set.", + stacklevel=2) + + # === LoRA Settings === + if getattr(args, "enable_lora", False) and args.backend != "vllm": + raise ValueError( + "LoRA benchmarking is only supported for vLLM backend") + if getattr(args, "enable_lora", False) and args.lora_path is None: + raise ValueError("LoRA path must be provided when enable_lora is True") + + # === Backend-specific Validations === + if args.backend == "hf" and args.hf_max_batch_size is None: + raise ValueError("HF max batch size is required for HF backend") + if args.backend != "hf" and args.hf_max_batch_size is not None: + raise ValueError("HF max batch size is only for HF backend.") + + if args.backend in {"hf", "mii"} and getattr(args, "quantization", + None) is not None: + raise ValueError("Quantization is only for vLLM backend.") + + if args.backend == "mii" and args.dtype != "auto": + raise ValueError("dtype must be auto for MII backend.") + if args.backend == "mii" and args.n != 1: + raise ValueError("n must be 1 for MII backend.") + if args.backend == "mii" and args.tokenizer != args.model: + raise ValueError( + "Tokenizer must be the same as the model for MII backend.") + + +def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument("--backend", + type=str, + choices=["vllm", "hf", "mii", "vllm-chat"], + default="vllm") + parser.add_argument( + "--dataset-name", + type=str, + choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"], + help="Name of the dataset to benchmark on.", + default="sharegpt") + parser.add_argument( + "--dataset", + type=str, + default=None, + help="Path to the ShareGPT dataset, will be deprecated in\ + the next release. The dataset is expected to " + "be a json in form of list[dict[..., conversations: " + "list[dict[..., value: ]]]]") + parser.add_argument("--dataset-path", + type=str, + default=None, + help="Path to the dataset") + parser.add_argument("--input-len", + type=int, + default=None, + help="Input prompt length for each request") + parser.add_argument("--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.") + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.") + parser.add_argument("--hf-max-batch-size", + type=int, + default=None, + help="Maximum batch size for HF backend.") + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') + parser.add_argument("--async-engine", + action='store_true', + default=False, + help="Use vLLM async engine rather than LLM class.") + parser.add_argument("--disable-frontend-multiprocessing", + action='store_true', + default=False, + help="Disable decoupled async engine frontend.") + parser.add_argument( + "--disable-detokenize", + action="store_true", + help=("Do not detokenize the response (i.e. do not include " + "detokenization time in the measurement)")) + # LoRA + parser.add_argument( + "--lora-path", + type=str, + default=None, + help="Path to the lora adapters to use. 
This can be an absolute path, " + "a relative path, or a Hugging Face model identifier.") + parser.add_argument( + "--prefix-len", + type=int, + default=0, + help="Number of fixed prefix tokens before the random " + "context in a request (default: 0).", + ) + # random dataset + parser.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range ratio for sampling input/output length, " + "used only for RandomDataset. Must be in the range [0, 1) to define " + "a symmetric sampling range " + "[length * (1 - range_ratio), length * (1 + range_ratio)].", + ) + + # hf dtaset + parser.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + parser.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + + parser = AsyncEngineArgs.add_cli_args(parser) + + +def main(args: argparse.Namespace): + if args.tokenizer is None: + args.tokenizer = args.model + validate_args(args) + if args.seed is None: + args.seed = 0 + print(args) + random.seed(args.seed) + # Sample the requests. + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code) + requests = get_requests(args, tokenizer) + is_multi_modal = any(request.multi_modal_data is not None + for request in requests) + request_outputs: Optional[list[RequestOutput]] = None + if args.backend == "vllm": + if args.async_engine: + elapsed_time = uvloop.run( + run_vllm_async( + requests, + args.n, + AsyncEngineArgs.from_cli_args(args), + args.disable_frontend_multiprocessing, + args.disable_detokenize, + )) + else: + elapsed_time, request_outputs = run_vllm( + requests, args.n, EngineArgs.from_cli_args(args), + args.disable_detokenize) + elif args.backend == "hf": + assert args.tensor_parallel_size == 1 + elapsed_time = run_hf(requests, args.model, tokenizer, args.n, + args.hf_max_batch_size, args.trust_remote_code, + args.disable_detokenize) + elif args.backend == "vllm-chat": + elapsed_time, request_outputs = run_vllm_chat( + requests, args.n, EngineArgs.from_cli_args(args), + args.disable_detokenize) + else: + raise ValueError(f"Unknown backend: {args.backend}") + + if request_outputs: + # Note: with the vllm and vllm-chat backends, + # we have request_outputs, which we use to count tokens. + total_prompt_tokens = 0 + total_output_tokens = 0 + for ro in request_outputs: + if not isinstance(ro, RequestOutput): + continue + total_prompt_tokens += len( + ro.prompt_token_ids) if ro.prompt_token_ids else 0 + total_output_tokens += sum( + len(o.token_ids) for o in ro.outputs if o) + total_num_tokens = total_prompt_tokens + total_output_tokens + else: + total_num_tokens = sum(r.prompt_len + r.expected_output_len + for r in requests) + total_output_tokens = sum(r.expected_output_len for r in requests) + total_prompt_tokens = total_num_tokens - total_output_tokens + + if is_multi_modal and args.backend != "vllm-chat": + print("\033[91mWARNING\033[0m: Multi-modal request with " + f"{args.backend} backend detected. The " + "following metrics are not accurate because image tokens are not" + " counted. See vllm-project/vllm/issues/9778 for details.") + # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length. 
+ # vllm-chat backend counts the image tokens now + + print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " + f"{total_output_tokens / elapsed_time:.2f} output tokens/s") + print(f"Total num prompt tokens: {total_prompt_tokens}") + print(f"Total num output tokens: {total_output_tokens}") + + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + save_to_pytorch_benchmark_format(args, results) diff --git a/vllm/entrypoints/cli/benchmark/latency.py b/vllm/entrypoints/cli/benchmark/latency.py new file mode 100644 index 00000000..5aca16e0 --- /dev/null +++ b/vllm/entrypoints/cli/benchmark/latency.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +import argparse + +from vllm.benchmarks.latency import add_cli_args, main +from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase +from vllm.entrypoints.cli.types import CLISubcommand + + +class BenchmarkLatencySubcommand(BenchmarkSubcommandBase): + """ The `latency` subcommand for vllm bench. """ + + def __init__(self): + self.name = "latency" + super().__init__() + + @property + def help(self) -> str: + return "Benchmark the latency of a single batch of requests." + + def add_cli_args(self, parser: argparse.ArgumentParser) -> None: + add_cli_args(parser) + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + main(args) + + +def cmd_init() -> list[CLISubcommand]: + return [BenchmarkLatencySubcommand()] diff --git a/vllm/entrypoints/cli/benchmark/main.py b/vllm/entrypoints/cli/benchmark/main.py index 1bcb25be..9e857af7 100644 --- a/vllm/entrypoints/cli/benchmark/main.py +++ b/vllm/entrypoints/cli/benchmark/main.py @@ -1,14 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import vllm.entrypoints.cli.benchmark.latency import vllm.entrypoints.cli.benchmark.serve +import vllm.entrypoints.cli.benchmark.throughput from vllm.entrypoints.cli.types import CLISubcommand from vllm.utils import FlexibleArgumentParser -# TODO: Add the rest of the benchmark subcommands here, -# e.g., throughput, latency, etc. BENCHMARK_CMD_MODULES = [ + vllm.entrypoints.cli.benchmark.latency, vllm.entrypoints.cli.benchmark.serve, + vllm.entrypoints.cli.benchmark.throughput, ] diff --git a/vllm/entrypoints/cli/benchmark/throughput.py b/vllm/entrypoints/cli/benchmark/throughput.py new file mode 100644 index 00000000..88ee6aa0 --- /dev/null +++ b/vllm/entrypoints/cli/benchmark/throughput.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +import argparse + +from vllm.benchmarks.throughput import add_cli_args, main +from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase +from vllm.entrypoints.cli.types import CLISubcommand + + +class BenchmarkThroughputSubcommand(BenchmarkSubcommandBase): + """ The `throughput` subcommand for vllm bench. """ + + def __init__(self): + self.name = "throughput" + super().__init__() + + @property + def help(self) -> str: + return "Benchmark offline inference throughput." 
+ + def add_cli_args(self, parser: argparse.ArgumentParser) -> None: + add_cli_args(parser) + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + main(args) + + +def cmd_init() -> list[CLISubcommand]: + return [BenchmarkThroughputSubcommand()]
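Note: a quick way to sanity-check the new entrypoints locally is to mirror what the CI tests in this patch do. The sketch below is based directly on the test invocations; the Llama-3.2-1B-Instruct model name and the dummy load format are simply the values those tests use, not requirements of the CLI.

import subprocess

# Run the two new offline benchmark subcommands with a tiny dummy workload,
# mirroring tests/benchmarks/test_latency_cli.py and test_throughput_cli.py.
for subcommand in ("latency", "throughput"):
    command = [
        "vllm", "bench", subcommand, "--model",
        "meta-llama/Llama-3.2-1B-Instruct", "--input-len", "32",
        "--output-len", "1", "--enforce-eager", "--load-format", "dummy",
    ]
    result = subprocess.run(command, capture_output=True, text=True)
    assert result.returncode == 0, f"{subcommand} benchmark failed: {result.stderr}"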