diff --git a/benchmarks/README.md b/benchmarks/README.md
index edc10d8b..c64c24fd 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -43,20 +43,26 @@ become available.
HuggingFace |
✅ |
- 🚧 |
+ 🟡 |
Specify your dataset path on HuggingFace |
VisionArena |
✅ |
- 🚧 |
+ ✅ |
lmarena-ai/vision-arena-bench-v0.1 (a HuggingFace dataset) |
-✅: supported
+
+✅: supported
+
🚧: to be supported
+
+🟡: Partial support. Currently, `HuggingFaceDataset` only supports dataset
+formats similar to `lmms-lab/LLaVA-OneVision-Data`. If you need support for
+other dataset formats, please consider contributing.
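+
+As a rough illustration of the format the loader expects, each row should
+provide a `conversations` list whose first two entries carry the prompt and
+completion under a `value` key, plus an optional `image` field, e.g.
+`{"conversations": [{"value": "<prompt>"}, {"value": "<completion>"}], "image": <image>}`.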
+
**Note**: VisionArena’s `dataset-name` should be set to `hf`
---
@@ -79,7 +85,7 @@ NUM_PROMPTS=10
BACKEND="openai-chat"
DATASET_NAME="sharegpt"
DATASET_PATH="/ShareGPT_V3_unfiltered_cleaned_split.json"
-python3 benchmarks/benchmark_serving.py --backend ${BACKEND} --model ${MODEL_NAME} --endpoint /v1/chat/completions --dataset-name ${DATASET_NAME} --dataset-path ${DATASET_PATH} --num-prompts ${NUM_PROMPTS}
+python3 vllm/benchmarks/benchmark_serving.py --backend ${BACKEND} --model ${MODEL_NAME} --endpoint /v1/chat/completions --dataset-name ${DATASET_NAME} --dataset-path ${DATASET_PATH} --num-prompts ${NUM_PROMPTS}
```
If successful, you will see the following output
@@ -123,7 +129,7 @@ DATASET_NAME="hf"
DATASET_PATH="lmarena-ai/vision-arena-bench-v0.1"
DATASET_SPLIT='train'
-python3 benchmarks/benchmark_serving.py \
+python3 vllm/benchmarks/benchmark_serving.py \
--backend "${BACKEND}" \
--model "${MODEL_NAME}" \
--endpoint "/v1/chat/completions" \
@@ -140,35 +146,65 @@ python3 benchmarks/benchmark_serving.py \
MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B"
NUM_PROMPTS=10
DATASET_NAME="sonnet"
-DATASET_PATH="benchmarks/sonnet.txt"
+DATASET_PATH="vllm/benchmarks/sonnet.txt"
-python3 benchmarks/benchmark_throughput.py \
+python3 vllm/benchmarks/benchmark_throughput.py \
--model "${MODEL_NAME}" \
--dataset-name "${DATASET_NAME}" \
--dataset-path "${DATASET_PATH}" \
--num-prompts "${NUM_PROMPTS}"
- ```
+```
If successful, you will see the following output
```
-Throughput: 7.35 requests/s, 4789.20 total tokens/s, 1102.83 output tokens/s
+Throughput: 7.15 requests/s, 4656.00 total tokens/s, 1072.15 output tokens/s
+Total num prompt tokens: 5014
+Total num output tokens: 1500
+```
+
+### VisionArena Benchmark for Vision Language Models
+
+``` bash
+MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
+NUM_PROMPTS=10
+DATASET_NAME="hf"
+DATASET_PATH="lmarena-ai/vision-arena-bench-v0.1"
+DATASET_SPLIT="train"
+
+python3 vllm/benchmarks/benchmark_throughput.py \
+ --model "${MODEL_NAME}" \
+ --backend "vllm-chat" \
+ --dataset-name "${DATASET_NAME}" \
+ --dataset-path "${DATASET_PATH}" \
+ --num-prompts "${NUM_PROMPTS}" \
+ --hf-split "${DATASET_SPLIT}"
+```
+
+The `Total num prompt tokens` metric now includes image token counts
+
+```
+Throughput: 2.55 requests/s, 4036.92 total tokens/s, 326.90 output tokens/s
+Total num prompt tokens: 14527
+Total num output tokens: 1280
```
### Benchmark with LoRA Adapters
``` bash
+# download dataset
+# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
MODEL_NAME="meta-llama/Llama-2-7b-hf"
BACKEND="vllm"
DATASET_NAME="sharegpt"
-DATASET_PATH="/home/jovyan/data/vllm_benchmark_datasets/ShareGPT_V3_unfiltered_cleaned_split.json"
+DATASET_PATH="/ShareGPT_V3_unfiltered_cleaned_split.json"
NUM_PROMPTS=10
MAX_LORAS=2
MAX_LORA_RANK=8
ENABLE_LORA="--enable-lora"
LORA_PATH="yard1/llama-2-7b-sql-lora-test"
-python3 benchmarks/benchmark_throughput.py \
+python3 vllm/benchmarks/benchmark_throughput.py \
--model "${MODEL_NAME}" \
--backend "${BACKEND}" \
--dataset_path "${DATASET_PATH}" \
diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py
index 30fffdda..55109dab 100644
--- a/benchmarks/benchmark_dataset.py
+++ b/benchmarks/benchmark_dataset.py
@@ -46,7 +46,7 @@ class SampleRequest:
Represents a single inference request for benchmarking.
"""
- prompt: str
+ prompt: Union[str, Any]
prompt_len: int
expected_output_len: int
multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None
@@ -84,6 +84,20 @@ class BenchmarkDataset(ABC):
if random_seed is not None else self.DEFAULT_SEED)
self.data = None
+ def apply_multimodal_chat_transformation(
+ self,
+ prompt: str,
+ mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
+ """
+ Transform a prompt and optional multimodal content into a chat format.
+ This method is used for chat models that expect a specific
+ conversation format.
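+
+        For example (illustrative), a text-only prompt "Hi" becomes:
+            [{"role": "user",
+              "content": [{"text": "Hi", "type": "text"}]}]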
+ """
+ content = [{"text": prompt, "type": "text"}]
+ if mm_content is not None:
+ content.append(mm_content)
+ return [{"role": "user", "content": content}]
+
def load_data(self) -> None:
"""
Load data from the dataset path into self.data.
@@ -338,6 +352,7 @@ class ShareGPTDataset(BenchmarkDataset):
lora_path: Optional[str] = None,
max_loras: Optional[int] = None,
output_len: Optional[int] = None,
+ enable_multimodal_chat: bool = False,
**kwargs) -> list:
samples: list = []
for entry in self.data:
@@ -358,6 +373,9 @@ class ShareGPTDataset(BenchmarkDataset):
skip_min_output_len_check=output_len
is not None):
continue
+ if enable_multimodal_chat:
+ prompt = self.apply_multimodal_chat_transformation(
+ prompt, None)
samples.append(
SampleRequest(
prompt=prompt,
@@ -550,10 +568,13 @@ class HuggingFaceDataset(BenchmarkDataset):
split=self.dataset_split,
streaming=True,
)
-
- if "conversations" not in self.data.features:
- raise ValueError("HF Dataset must have a 'conversations' column.")
-
+ if self.data.features is None or "conversations" \
+ not in self.data.features:
+ raise ValueError(
+ "HuggingFaceDataset currently only supports datasets with "
+ "a 'conversations' column like lmms-lab/LLaVA-OneVision-Data. "
+ "Please consider contributing if you would like to add "
+ "support for additional dataset formats.")
# Shuffle and filter examples with at least 2 conversations.
self.data = self.data.shuffle(seed=self.random_seed).filter(
lambda x: len(x["conversations"]) >= 2)
@@ -561,9 +582,8 @@ class HuggingFaceDataset(BenchmarkDataset):
def sample(self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
- lora_path: Optional[str] = None,
- max_loras: Optional[int] = None,
output_len: Optional[int] = None,
+ enable_multimodal_chat: bool = False,
**kwargs) -> list:
sampled_requests = []
dynamic_output = output_len is None
@@ -571,13 +591,9 @@ class HuggingFaceDataset(BenchmarkDataset):
for item in self.data:
if len(sampled_requests) >= num_requests:
break
-
conv = item["conversations"]
prompt, completion = conv[0]["value"], conv[1]["value"]
- lora_request, tokenizer = self.get_random_lora_request(
- tokenizer, lora_path=lora_path, max_loras=max_loras)
-
prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_ids)
@@ -587,16 +603,20 @@ class HuggingFaceDataset(BenchmarkDataset):
if dynamic_output and not is_valid_sequence(
prompt_len, completion_len):
continue
-
mm_content = process_image(
item["image"]) if "image" in item else None
+ if enable_multimodal_chat:
+                # Note: when chat is enabled, the request prompt_len is no
+                # longer accurate; we rely on the request output to count the
+                # actual prompt and output lengths.
+ prompt = self.apply_multimodal_chat_transformation(
+ prompt, mm_content)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
- lora_request=lora_request,
))
return sampled_requests
@@ -606,7 +626,7 @@ class HuggingFaceDataset(BenchmarkDataset):
# -----------------------------------------------------------------------------
-class VisionArenaDataset(BenchmarkDataset):
+class VisionArenaDataset(HuggingFaceDataset):
"""
Vision Arena Dataset.
"""
@@ -617,14 +637,9 @@ class VisionArenaDataset(BenchmarkDataset):
def __init__(
self,
- dataset_split: str,
- dataset_subset: Optional[str] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
- self.dataset_split = dataset_split
- self.dataset_subset = dataset_subset
-
if self.dataset_path != self.VISION_ARENA_DATASET_PATH:
raise ValueError(f"Only support Vision Arena dataset.\
This data path {self.dataset_path} is not valid.")
@@ -645,9 +660,9 @@ class VisionArenaDataset(BenchmarkDataset):
def sample(self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
- output_len: int = DEFAULT_OUTPUT_LEN,
+ output_len: Optional[int] = None,
+ enable_multimodal_chat: bool = False,
**kwargs) -> list:
- # TODO (jenniferzhao): Add support for offline benchmark sampling
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = []
@@ -655,8 +670,14 @@ class VisionArenaDataset(BenchmarkDataset):
if len(sampled_requests) >= num_requests:
break
prompt = item["turns"][0][0]["content"]
- prompt_len = len(tokenizer(prompt).input_ids)
mm_content = process_image(item["images"][0])
+ prompt_len = len(tokenizer(prompt).input_ids)
+ if enable_multimodal_chat:
+                # Note: when chat is enabled, the request prompt_len is no
+                # longer accurate; we rely on the request output to count the
+                # actual prompt length.
+ prompt = self.apply_multimodal_chat_transformation(
+ prompt, mm_content)
sampled_requests.append(
SampleRequest(
prompt=prompt,
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 7e655673..53869db4 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -11,8 +11,9 @@ from typing import Any, Optional, Union
import torch
import uvloop
-from benchmark_dataset import (BurstGPTDataset, RandomDataset, SampleRequest,
- ShareGPTDataset, SonnetDataset)
+from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
+ RandomDataset, SampleRequest, ShareGPTDataset,
+ SonnetDataset, VisionArenaDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
@@ -23,6 +24,7 @@ from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
+from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
@@ -32,7 +34,7 @@ def run_vllm(
n: int,
engine_args: EngineArgs,
disable_detokenize: bool = False,
-) -> float:
+) -> tuple[float, Optional[list[RequestOutput]]]:
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))
assert all(
@@ -66,12 +68,13 @@ def run_vllm(
use_beam_search = False
+ outputs = None
if not use_beam_search:
start = time.perf_counter()
- llm.generate(prompts,
- sampling_params,
- lora_request=lora_requests,
- use_tqdm=True)
+ outputs = llm.generate(prompts,
+ sampling_params,
+ lora_request=lora_requests,
+ use_tqdm=True)
end = time.perf_counter()
else:
assert lora_requests is None, "BeamSearch API does not support LoRA"
@@ -89,7 +92,46 @@ def run_vllm(
ignore_eos=True,
))
end = time.perf_counter()
- return end - start
+ return end - start, outputs
+
+
+def run_vllm_chat(
+ requests: list[SampleRequest],
+ n: int,
+ engine_args: EngineArgs,
+ disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]:
+ """
+ Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
+ multimodal models as it properly handles multimodal inputs and chat
+ formatting. For non-multimodal models, use run_vllm() instead.
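+
+    Returns the elapsed wall-clock time in seconds together with the list of
+    RequestOutput objects, which the caller uses to count the actual prompt
+    and output tokens.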
+ """
+ from vllm import LLM, SamplingParams
+ llm = LLM(**dataclasses.asdict(engine_args))
+
+ assert all(
+ llm.llm_engine.model_config.max_model_len >= (
+ request.prompt_len + request.expected_output_len)
+ for request in requests), (
+ "Please ensure that max_model_len is greater than the sum of "
+ "prompt_len and expected_output_len for all requests.")
+
+ prompts = []
+ sampling_params: list[SamplingParams] = []
+ for request in requests:
+ prompts.append(request.prompt)
+ sampling_params.append(
+ SamplingParams(
+ n=n,
+ temperature=1.0,
+ top_p=1.0,
+ ignore_eos=True,
+ max_tokens=request.expected_output_len,
+ detokenize=not disable_detokenize,
+ ))
+ start = time.perf_counter()
+ outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
+ end = time.perf_counter()
+ return end - start, outputs
async def run_vllm_async(
@@ -264,6 +306,8 @@ def get_requests(args, tokenizer):
dataset_cls = RandomDataset
elif args.dataset_name == "sharegpt":
dataset_cls = ShareGPTDataset
+ if args.backend == "vllm-chat":
+ sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_name == "sonnet":
assert tokenizer.chat_template or tokenizer.default_chat_template, (
"Tokenizer/model must have chat template for sonnet dataset.")
@@ -272,6 +316,19 @@ def get_requests(args, tokenizer):
sample_kwargs["return_prompt_formatted"] = True
elif args.dataset_name == "burstgpt":
dataset_cls = BurstGPTDataset
+ elif args.dataset_name == "hf":
+ if args.backend != "vllm-chat":
+ raise ValueError(
+ "hf datasets only are supported by vllm-chat backend")
+        # Use VisionArenaDataset when the dataset path matches its canonical
+        # HuggingFace path and no subset is given; otherwise fall back to
+        # the generic HuggingFaceDataset.
+ dataset_cls = (VisionArenaDataset if args.dataset_path
+ == VisionArenaDataset.VISION_ARENA_DATASET_PATH
+ and args.hf_subset is None else HuggingFaceDataset)
+ common_kwargs['dataset_subset'] = args.hf_subset
+ common_kwargs['dataset_split'] = args.hf_split
+ sample_kwargs["enable_multimodal_chat"] = True
+
else:
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
# Remove None values
@@ -290,6 +347,7 @@ def main(args: argparse.Namespace):
requests = get_requests(args, tokenizer)
is_multi_modal = any(request.multi_modal_data is not None
for request in requests)
+ request_outputs: Optional[list[RequestOutput]] = None
if args.backend == "vllm":
if args.async_engine:
elapsed_time = uvloop.run(
@@ -301,9 +359,9 @@ def main(args: argparse.Namespace):
args.disable_detokenize,
))
else:
- elapsed_time = run_vllm(requests, args.n,
- EngineArgs.from_cli_args(args),
- args.disable_detokenize)
+ elapsed_time, request_outputs = run_vllm(
+ requests, args.n, EngineArgs.from_cli_args(args),
+ args.disable_detokenize)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -312,20 +370,45 @@ def main(args: argparse.Namespace):
elif args.backend == "mii":
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
args.output_len)
+ elif args.backend == "vllm-chat":
+ elapsed_time, request_outputs = run_vllm_chat(
+ requests, args.n, EngineArgs.from_cli_args(args),
+ args.disable_detokenize)
else:
raise ValueError(f"Unknown backend: {args.backend}")
- total_num_tokens = sum(request.prompt_len + request.expected_output_len
- for request in requests)
- total_output_tokens = sum(request.expected_output_len
- for request in requests)
- if is_multi_modal:
- print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
+
+ if request_outputs:
+        # Note: the vllm and vllm-chat backends return request_outputs,
+        # which we use to count the actual prompt and output tokens.
+ total_prompt_tokens = 0
+ total_output_tokens = 0
+ for ro in request_outputs:
+ if not isinstance(ro, RequestOutput):
+ continue
+ total_prompt_tokens += len(
+ ro.prompt_token_ids) if ro.prompt_token_ids else 0
+ total_output_tokens += sum(
+ len(o.token_ids) for o in ro.outputs if o)
+ total_num_tokens = total_prompt_tokens + total_output_tokens
+ else:
+ total_num_tokens = sum(r.prompt_len + r.expected_output_len
+ for r in requests)
+ total_output_tokens = sum(r.expected_output_len for r in requests)
+ total_prompt_tokens = total_num_tokens - total_output_tokens
+
+ if is_multi_modal and args.backend != "vllm-chat":
+ print("\033[91mWARNING\033[0m: Multi-modal request with "
+ f"{args.backend} backend detected. The "
"following metrics are not accurate because image tokens are not"
" counted. See vllm-project/vllm/issues/9778 for details.")
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
+    # The vllm-chat backend counts image tokens via the request outputs.
+
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
+ print(f"Total num prompt tokens: {total_prompt_tokens}")
+ print(f"Total num output tokens: {total_output_tokens}")
# Output JSON results if specified
if args.output_json:
@@ -341,17 +424,100 @@ def main(args: argparse.Namespace):
save_to_pytorch_benchmark_format(args, results)
+def validate_args(args):
+ """
+ Validate command-line arguments.
+ """
+
+ # === Deprecation and Defaulting ===
+ if args.dataset is not None:
+ warnings.warn(
+ "The '--dataset' argument will be deprecated in the next release. "
+ "Please use '--dataset-name' and '--dataset-path' instead.",
+ stacklevel=2)
+ args.dataset_path = args.dataset
+
+ if not getattr(args, "tokenizer", None):
+ args.tokenizer = args.model
+
+ # === Backend Validation ===
+ valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
+ if args.backend not in valid_backends:
+ raise ValueError(f"Unsupported backend: {args.backend}")
+
+ # === Dataset Configuration ===
+ if not args.dataset and not args.dataset_path:
+        print("No dataset path provided; defaulting to the random dataset.")
+ args.dataset_name = 'random'
+ if args.input_len is None:
+ raise ValueError("input_len must be provided for a random dataset")
+
+ # === Dataset Name Specific Checks ===
+ # --hf-subset and --hf-split: only used
+ # when dataset_name is 'hf'
+ if args.dataset_name != "hf" and (
+ getattr(args, "hf_subset", None) is not None
+ or getattr(args, "hf_split", None) is not None):
+ warnings.warn("--hf-subset and --hf-split will be ignored \
+ since --dataset-name is not 'hf'.",
+ stacklevel=2)
+ elif args.dataset_name == "hf" and args.backend != "vllm-chat":
+ raise ValueError(
+ "When --dataset-name is 'hf', backend must be 'vllm-chat'")
+
+ # --random-range-ratio: only used when dataset_name is 'random'
+ if args.dataset_name != 'random' and args.random_range_ratio is not None:
+ warnings.warn("--random-range-ratio will be ignored since \
+ --dataset-name is not 'random'.",
+ stacklevel=2)
+
+ # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
+ # set.
+ if args.dataset_name not in {"random", "sonnet", None
+ } and args.prefix_len is not None:
+ warnings.warn("--prefix-len will be ignored since --dataset-name\
+ is not 'random', 'sonnet', or not set.",
+ stacklevel=2)
+
+ # === LoRA Settings ===
+ if getattr(args, "enable_lora", False) and args.backend != "vllm":
+ raise ValueError(
+ "LoRA benchmarking is only supported for vLLM backend")
+ if getattr(args, "enable_lora", False) and args.lora_path is None:
+ raise ValueError("LoRA path must be provided when enable_lora is True")
+
+ # === Backend-specific Validations ===
+ if args.backend == "hf" and args.hf_max_batch_size is None:
+ raise ValueError("HF max batch size is required for HF backend")
+ if args.backend != "hf" and args.hf_max_batch_size is not None:
+ raise ValueError("HF max batch size is only for HF backend.")
+
+ if args.backend in {"hf", "mii"} and getattr(args, "quantization",
+ None) is not None:
+ raise ValueError("Quantization is only for vLLM backend.")
+
+ if args.backend == "mii" and args.dtype != "auto":
+ raise ValueError("dtype must be auto for MII backend.")
+ if args.backend == "mii" and args.n != 1:
+ raise ValueError("n must be 1 for MII backend.")
+ if args.backend == "mii" and args.tokenizer != args.model:
+ raise ValueError(
+ "Tokenizer must be the same as the model for MII backend.")
+
+
if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend",
type=str,
- choices=["vllm", "hf", "mii"],
+ choices=["vllm", "hf", "mii", "vllm-chat"],
default="vllm")
- parser.add_argument("--dataset-name",
- type=str,
- choices=["sharegpt", "random", "sonnet", "burstgpt"],
- help="Name of the dataset to benchmark on.",
- default="sharegpt")
+ parser.add_argument(
+ "--dataset-name",
+ type=str,
+ choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
+ help="Name of the dataset to benchmark on.",
+ default="sharegpt")
parser.add_argument(
"--dataset",
type=str,
@@ -419,55 +585,24 @@ if __name__ == "__main__":
parser.add_argument(
"--random-range-ratio",
type=float,
- default=1.0,
+ default=None,
help="Range of sampled ratio of input/output length, "
"used only for RandomDataSet.",
)
+    # hf dataset
+ parser.add_argument("--hf-subset",
+ type=str,
+ default=None,
+ help="Subset of the HF dataset.")
+ parser.add_argument("--hf-split",
+ type=str,
+ default=None,
+ help="Split of the HF dataset.")
+
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
- if args.dataset is not None:
- warnings.warn(
- "The '--dataset' argument will be deprecated in the next "
- "release. Please use '--dataset-name' and "
- "'--dataset-path' in the future runs.",
- stacklevel=2)
- args.dataset_path = args.dataset
- if args.dataset is None and args.dataset_path is None:
- # for random dataset, the default sampling setting is in
- # benchmark_dataset.RandomDataset
- print("When dataset is not set, it will default to random dataset")
- else:
- assert args.input_len is None
- if args.enable_lora:
- assert args.lora_path is not None
-
- if args.backend == "vllm":
- if args.hf_max_batch_size is not None:
- raise ValueError("HF max batch size is only for HF backend.")
- elif args.backend == "hf":
- if args.hf_max_batch_size is None:
- raise ValueError("HF max batch size is required for HF backend.")
- if args.quantization is not None:
- raise ValueError("Quantization is only for vLLM backend.")
- if args.enable_lora is not None:
- raise ValueError("LoRA benchmarking is only supported for vLLM"
- " backend")
- elif args.backend == "mii":
- if args.dtype != "auto":
- raise ValueError("dtype must be auto for MII backend.")
- if args.n != 1:
- raise ValueError("n must be 1 for MII backend.")
- if args.quantization is not None:
- raise ValueError("Quantization is only for vLLM backend.")
- if args.hf_max_batch_size is not None:
- raise ValueError("HF max batch size is only for HF backend.")
- if args.tokenizer != args.model:
- raise ValueError("Tokenizer must be the same as the model for MII "
- "backend.")
- if args.enable_lora is not None:
- raise ValueError("LoRA benchmarking is only supported for vLLM"
- " backend")
+ validate_args(args)
main(args)