2023-06-14 19:55:38 -07:00
|
|
|
"""Benchmark offline inference throughput."""
|
2023-05-28 03:20:05 -07:00
|
|
|
import argparse
|
2024-10-22 17:40:38 -05:00
|
|
|
import dataclasses
|
2023-05-28 03:20:05 -07:00
|
|
|
import json
|
|
|
|
import random
|
|
|
|
import time
|
2024-11-04 14:32:16 -08:00
|
|
|
from typing import List, Optional
|
2023-05-28 03:20:05 -07:00
|
|
|
|
2023-06-14 19:55:38 -07:00
|
|
|
import torch
|
2024-09-03 17:57:41 -07:00
|
|
|
import uvloop
|
2024-03-25 23:59:47 +09:00
|
|
|
from tqdm import tqdm
|
2023-11-14 12:35:30 -08:00
|
|
|
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
|
|
|
PreTrainedTokenizerBase)
|
2023-06-14 19:55:38 -07:00
|
|
|
|
2024-10-22 17:40:38 -05:00
|
|
|
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
2024-09-03 17:57:41 -07:00
|
|
|
from vllm.entrypoints.openai.api_server import (
|
|
|
|
build_async_engine_client_from_engine_args)
|
2024-11-04 14:32:16 -08:00
|
|
|
from vllm.inputs import TextPrompt
|
|
|
|
from vllm.multimodal import MultiModalDataDict
|
2024-10-05 23:39:03 -07:00
|
|
|
from vllm.sampling_params import BeamSearchParams
|
2024-09-03 17:57:41 -07:00
|
|
|
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
|
2024-04-18 03:21:55 -04:00
|
|
|
|
2023-05-28 03:20:05 -07:00
|
|
|
|
2024-11-04 14:32:16 -08:00
|
|
|
@dataclasses.dataclass
class SampleRequest:
    """One benchmark request: a prompt plus the token budget it represents.

    Attributes:
        prompt: Text prompt sent to the model.
        prompt_len: Number of tokens in ``prompt`` as counted by the
            benchmark's tokenizer.
        expected_output_len: Number of output tokens the request is expected
            to generate.
        multi_modal_data: Optional multi-modal payload (e.g. images) attached
            to the prompt; ``None`` for text-only requests.
    """

    # Required fields; their order defines the positional constructor.
    prompt: str
    prompt_len: int
    expected_output_len: int
    # Optional multi-modal inputs; defaults to no extra data.
    multi_modal_data: Optional[MultiModalDataDict] = dataclasses.field(
        default=None)
|
|
|
|
|
|
|
|
|
2023-05-28 03:20:05 -07:00
|
|
|
def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> List[SampleRequest]:
    """Load, shuffle and length-filter conversations into benchmark requests.

    The dataset is a JSON list of objects, each with a ``"conversations"``
    list whose items carry a ``"value"`` string. The first turn is used as
    the prompt and the second turn as the reference completion.

    Args:
        dataset_path: Path to the JSON dataset file.
        num_requests: Maximum number of requests to return.
        tokenizer: Tokenizer used to measure prompt/completion lengths.
        fixed_output_len: If not None, used as every request's expected
            output length instead of the reference completion's length.
            Must be at least 4.

    Returns:
        Up to ``num_requests`` requests in shuffled order. Each has a prompt
        of 4-1024 tokens, an output length of at least 4 tokens, and a
        prompt+output total of at most 2048 tokens.

    Raises:
        ValueError: If ``fixed_output_len`` is smaller than 4.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset. JSON is UTF-8 by spec, so don't depend on the locale.
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)
    # Keep the first two turns of every conversation that has at least two.
    pairs = [(data["conversations"][0]["value"],
              data["conversations"][1]["value"]) for data in dataset
             if len(data["conversations"]) >= 2]

    # Shuffle using the global `random` state so the subset is a random one.
    random.shuffle(pairs)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[SampleRequest] = []
    for prompt, completion in pairs:
        if len(filtered_dataset) == num_requests:
            break

        prompt_len = len(tokenizer(prompt).input_ids)
        if fixed_output_len is None:
            # Only tokenize the completion when its length is actually used.
            output_len = len(tokenizer(completion).input_ids)
        else:
            output_len = fixed_output_len

        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append(
            SampleRequest(prompt=prompt,
                          prompt_len=prompt_len,
                          expected_output_len=output_len))

    return filtered_dataset
|
2023-05-28 03:20:05 -07:00
|
|
|
|
|
|
|
|
2023-06-17 03:07:40 -07:00
|
|
|
def run_vllm(
    requests: List[SampleRequest],
    n: int,
    engine_args: EngineArgs,
) -> float:
    """Time offline generation of ``requests`` with the vLLM ``LLM`` class.

    Args:
        requests: Requests to run. Each one's ``expected_output_len`` is used
            as its ``max_tokens``, and EOS is ignored.
        n: Sequences sampled per prompt (beam width on the beam-search path).
        engine_args: Engine configuration, expanded into ``LLM(**...)``.

    Returns:
        Wall-clock seconds spent in generation. Engine construction and
        request building are excluded.
    """
    from vllm import LLM, SamplingParams
    llm = LLM(**dataclasses.asdict(engine_args))

    # Add the requests to the engine.
    prompts: List[TextPrompt] = []
    sampling_params: List[SamplingParams] = []
    for request in requests:
        prompts.append(TextPrompt(prompt=request.prompt))
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
            ))

    # Beam search is hard-coded off; the branch below is kept for manual use.
    use_beam_search = False

    if not use_beam_search:
        start = time.perf_counter()
        llm.generate(prompts, sampling_params, use_tqdm=True)
        end = time.perf_counter()
    else:
        beam_prompts = [request.prompt for request in requests]
        # output_len should be the same for all requests. SampleRequest is a
        # dataclass, so read the field by name rather than indexing it.
        output_len = requests[0].expected_output_len
        for request in requests:
            assert request.expected_output_len == output_len
        start = time.perf_counter()
        llm.beam_search(
            beam_prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ))
        end = time.perf_counter()
    return end - start
|
|
|
|
|
|
|
|
|
2024-09-03 17:57:41 -07:00
|
|
|
async def run_vllm_async(
    requests: List[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> float:
    """Time generation of ``requests`` through an async vLLM engine client.

    Every request is submitted up front with its own generator (request ids
    ``test0``, ``test1``, ...). The merged output stream is then drained
    until it is exhausted.

    Args:
        requests: Requests to run. Each one's ``expected_output_len`` is used
            as its ``max_tokens``, and EOS is ignored.
        n: Sequences sampled per prompt.
        engine_args: Configuration used to build the async engine client.
        disable_frontend_multiprocessing: Forwarded to the client builder.

    Returns:
        Wall-clock seconds from submitting the requests until all streamed
        outputs have been consumed. Client start-up is excluded.
    """
    from vllm import SamplingParams

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:
        # One text prompt and one sampling configuration per request.
        prompts: List[TextPrompt] = [
            TextPrompt(prompt=req.prompt) for req in requests
        ]
        sampling_params: List[SamplingParams] = [
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=req.expected_output_len,
            ) for req in requests
        ]

        start = time.perf_counter()
        generators = [
            llm.generate(prompt, params, request_id=f"test{idx}")
            for idx, (prompt, params) in enumerate(
                zip(prompts, sampling_params))
        ]
        # Drain every stream; the generated outputs themselves are discarded.
        async for _idx, _output in merge_async_iterators(*generators):
            pass
        end = time.perf_counter()
        return end - start
|
|
|
|
|
|
|
|
|
2023-06-14 19:55:38 -07:00
|
|
|
def run_hf(
    requests: List[SampleRequest],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    """Time generation of ``requests`` with Hugging Face ``generate``.

    The model is loaded in float16 and moved to CUDA. Requests are batched in
    order. A batch keeps growing while it holds fewer than
    ``max_batch_size`` prompts and, with the next request added, the largest
    prompt length plus the largest output length stays within 2048 tokens.

    Args:
        requests: Requests to run, in order.
        model: Model name or path passed to ``from_pretrained``.
        tokenizer: Tokenizer used to pad/encode batches and decode outputs.
        n: ``num_return_sequences`` per prompt.
        max_batch_size: Upper bound on prompts per batch.
        trust_remote_code: Forwarded to ``from_pretrained``.

    Returns:
        Wall-clock seconds for batching, generation and decoding. Model
        loading is excluded.
    """
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
    for i, request in enumerate(requests):
        # SampleRequest is a dataclass (not a tuple): read fields by name.
        batch.append(request.prompt)
        max_prompt_len = max(max_prompt_len, request.prompt_len)
        max_output_len = max(max_output_len, request.expected_output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Check if we can add more requests to the batch.
            next_request = requests[i + 1]
            next_prompt_len = next_request.prompt_len
            next_output_len = next_request.expected_output_len
            if (max(max_prompt_len, next_prompt_len) +
                    max(max_output_len, next_output_len)) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    pbar.close()
    return end - start
|
|
|
|
|
|
|
|
|
2023-11-14 12:35:30 -08:00
|
|
|
def run_mii(
    requests: List[SampleRequest],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Time generation of ``requests`` with a DeepSpeed-MII server.

    Starts a server with ``mii.serve``, generates all prompts in one call,
    and always terminates the server afterwards.

    Args:
        requests: Requests whose prompts are generated.
        model: Model name passed to ``serve`` and ``client``.
        tensor_parallel_size: Passed as ``tensor_parallel`` to ``serve``.
        output_len: ``max_new_tokens`` used for every prompt.

    Returns:
        Wall-clock seconds spent in ``generate``. Server start-up and
        shutdown are excluded.
    """
    from mii import client, serve
    llm = serve(model, tensor_parallel=tensor_parallel_size)
    try:
        prompts = [request.prompt for request in requests]

        start = time.perf_counter()
        llm.generate(prompts, max_new_tokens=output_len)
        end = time.perf_counter()
    finally:
        # Terminate the server even if generation fails. Use a distinct name
        # so the imported ``client`` factory is not shadowed.
        mii_client = client(model)
        mii_client.terminate_server()
    return end - start
|
|
|
|
|
|
|
|
|
2023-06-14 19:55:38 -07:00
|
|
|
def main(args: argparse.Namespace):
    """Build the request list, run the chosen backend and report throughput.

    Prints the parsed arguments and the measured request/token throughput.
    When ``args.output_json`` is set, also writes a JSON summary to that path.
    """
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is not None:
        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                   args.output_len)
    else:
        # Synthesize a prompt of exactly `input_len` tokens. The tokenizer may
        # add extra tokens such as BOS, so probe nearby repetition counts.
        for delta in range(-10, 10):
            prompt = "hi " * (args.input_len + delta)
            if len(tokenizer(prompt).input_ids) == args.input_len:
                break
        else:
            raise ValueError(
                f"Failed to synthesize a prompt with {args.input_len} tokens.")
        requests = [
            SampleRequest(prompt=prompt,
                          prompt_len=args.input_len,
                          expected_output_len=args.output_len)
            for _ in range(args.num_prompts)
        ]

    backend = args.backend
    if backend == "vllm":
        if not args.async_engine:
            elapsed_time = run_vllm(requests, args.n,
                                    EngineArgs.from_cli_args(args))
        else:
            elapsed_time = uvloop.run(
                run_vllm_async(
                    requests,
                    args.n,
                    AsyncEngineArgs.from_cli_args(args),
                    args.disable_frontend_multiprocessing,
                ))
    elif backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.hf_max_batch_size, args.trust_remote_code)
    elif backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                               args.output_len)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    num_requests = len(requests)
    total_output_tokens = sum(req.expected_output_len for req in requests)
    total_num_tokens = total_output_tokens + sum(req.prompt_len
                                                 for req in requests)
    print(f"Throughput: {num_requests / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
          f"{total_output_tokens / elapsed_time:.2f} output tokens/s")

    # Output JSON results if specified
    if not args.output_json:
        return
    results = {
        "elapsed_time": elapsed_time,
        "num_requests": num_requests,
        "total_num_tokens": total_num_tokens,
        "requests_per_second": num_requests / elapsed_time,
        "tokens_per_second": total_num_tokens / elapsed_time,
    }
    with open(args.output_json, "w") as f:
        json.dump(results, f, indent=4)
|
|
|
|
|
2023-05-28 03:20:05 -07:00
|
|
|
|
|
|
|
if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset. The dataset is expected to "
                        "be a json in form of List[Dict[..., conversations: "
                        "List[Dict[..., value: <prompt_or_response>]]]]")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument("--async-engine",
                        action='store_true',
                        default=False,
                        help="Use vLLM async engine rather than LLM class.")
    parser.add_argument("--disable-frontend-multiprocessing",
                        action='store_true',
                        default=False,
                        help="Disable decoupled async engine frontend.")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model

    # Validate with explicit exceptions rather than `assert`, so the checks
    # still run under `python -O`. This matches the ValueErrors used below.
    if args.dataset is None:
        if args.input_len is None:
            raise ValueError(
                "--input-len is required when --dataset is not given.")
        if args.output_len is None:
            raise ValueError(
                "--output-len is required when --dataset is not given.")
    elif args.input_len is not None:
        raise ValueError("--input-len cannot be used together with --dataset.")

    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
    main(args)
|