[misc] benchmark_throughput : Add LoRA (#11267)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
Varun Sundar Rabindranath 2024-12-19 02:43:16 -05:00 committed by GitHub
parent f26c4aeecb
commit 98356735ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,7 +4,8 @@ import dataclasses
import json
import random
import time
from typing import List, Optional
from functools import cache
from typing import Dict, List, Optional, Tuple
import torch
import uvloop
@ -17,8 +18,11 @@ from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.inputs import TextPrompt
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import BeamSearchParams
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
@ -28,15 +32,17 @@ class SampleRequest:
Attributes:
prompt: The input text prompt for the model.
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
images).
prompt_len: The length of the prompt in tokens.
expected_output_len: The expected length of the output in tokens.
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
images).
lora_request: Optional LoRARequest specifying the LoRA to use.
"""
prompt: str
prompt_len: int
expected_output_len: int
multi_modal_data: Optional[MultiModalDataDict] = None
lora_request: Optional[LoRARequest] = None
def _get_prompt_for_image_model(question: str, *, model: str) -> str:
@ -60,8 +66,30 @@ def _get_prompt_for_image_model(question: str, *, model: str) -> str:
raise ValueError(f"Unsupported model {model}")
@cache
def lora_path_on_disk(lora_path: str) -> str:
return get_adapter_absolute_path(lora_path)
lora_tokenizer_cache: Dict[int, AnyTokenizer] = {}
def get_random_lora_request(
args: argparse.Namespace
) -> Tuple[LoRARequest, Optional[AnyTokenizer]]:
global lora_tokenizer_cache
lora_id = random.randint(1, args.max_loras)
lora_request = LoRARequest(lora_name=str(lora_id),
lora_int_id=lora_id,
lora_path=lora_path_on_disk(args.lora_path))
if lora_id not in lora_tokenizer_cache:
lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
return lora_request, lora_tokenizer_cache[lora_id]
def sample_requests(tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace) -> List[SampleRequest]:
dataset_path: str = args.dataset
num_requests: int = args.num_prompts
fixed_output_len: Optional[int] = args.output_len
@ -79,7 +107,9 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
# Filter out sequences that are too long or too short
filtered_dataset: List[SampleRequest] = []
for data in dataset:
for data in tqdm(dataset,
total=len(filtered_dataset),
desc="sampling requests"):
if len(filtered_dataset) == num_requests:
break
@ -102,9 +132,16 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
continue
prompt = _get_prompt_for_image_model(question=prompt, model=model)
request_tokenizer = tokenizer
lora_request: Optional[LoRARequest] = None
if args.enable_lora:
lora_request, lora_tokenizer = get_random_lora_request(args)
if lora_tokenizer:
request_tokenizer = lora_tokenizer
# Tokenize the prompts and completions.
prompt_token_ids = tokenizer(prompt).input_ids
completion_token_ids = tokenizer(completion).input_ids
prompt_token_ids = request_tokenizer(prompt).input_ids
completion_token_ids = request_tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
@ -118,7 +155,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
SampleRequest(prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=multi_modal_data))
multi_modal_data=multi_modal_data,
lora_request=lora_request))
return filtered_dataset
@ -146,14 +184,21 @@ def run_vllm(
ignore_eos=True,
max_tokens=request.expected_output_len,
))
lora_requests: Optional[List[LoRARequest]] = None
if engine_args.enable_lora:
lora_requests = [request.lora_request for request in requests]
use_beam_search = False
if not use_beam_search:
start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
llm.generate(prompts,
sampling_params,
lora_request=lora_requests,
use_tqdm=True)
end = time.perf_counter()
else:
assert lora_requests is None, "BeamSearch API does not support LoRA"
prompts = [request.prompt for request in requests]
# output_len should be the same for all requests.
output_len = requests[0][2]
@ -185,6 +230,7 @@ async def run_vllm_async(
# Add the requests to the engine.
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
lora_requests: List[Optional[LoRARequest]] = []
for request in requests:
prompts.append(
TextPrompt(prompt=request.prompt,
@ -197,11 +243,16 @@ async def run_vllm_async(
ignore_eos=True,
max_tokens=request.expected_output_len,
))
lora_requests.append(request.lora_request)
generators = []
start = time.perf_counter()
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
generator = llm.generate(prompt, sp, request_id=f"test{i}")
for i, (prompt, sp,
lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
generator = llm.generate(prompt,
sp,
lora_request=lr,
request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
async for i, res in all_gens:
@ -297,6 +348,14 @@ def main(args: argparse.Namespace):
vocab_size = tokenizer.vocab_size
requests = []
for _ in range(args.num_prompts):
request_tokenizer = tokenizer
lora_request: Optional[LoRARequest] = None
if args.enable_lora:
lora_request, lora_tokenizer = get_random_lora_request(args)
if lora_tokenizer:
request_tokenizer = lora_tokenizer
# Synthesize a prompt with the given input length.
candidate_ids = [
random.randint(0, vocab_size - 1)
@ -305,8 +364,8 @@ def main(args: argparse.Namespace):
# As tokenizer may add additional tokens like BOS, we need to try
# different lengths to get the desired input length.
for _ in range(5): # Max attempts to correct
candidate_prompt = tokenizer.decode(candidate_ids)
tokenized_len = len(tokenizer.encode(candidate_prompt))
candidate_prompt = request_tokenizer.decode(candidate_ids)
tokenized_len = len(request_tokenizer.encode(candidate_prompt))
if tokenized_len == args.input_len:
break
@ -323,7 +382,8 @@ def main(args: argparse.Namespace):
requests.append(
SampleRequest(prompt=candidate_prompt,
prompt_len=args.input_len,
expected_output_len=args.output_len))
expected_output_len=args.output_len,
lora_request=lora_request))
else:
requests = sample_requests(tokenizer, args)
@ -422,6 +482,14 @@ if __name__ == "__main__":
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
# LoRA
parser.add_argument(
"--lora-path",
type=str,
default=None,
help="Path to the lora adapters to use. This can be an absolute path, "
"a relative path, or a Hugging Face model identifier.")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.tokenizer is None:
@ -431,6 +499,8 @@ if __name__ == "__main__":
assert args.output_len is not None
else:
assert args.input_len is None
if args.enable_lora:
assert args.lora_path is not None
if args.backend == "vllm":
if args.hf_max_batch_size is not None:
@ -440,6 +510,9 @@ if __name__ == "__main__":
raise ValueError("HF max batch size is required for HF backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
if args.enable_lora is not None:
raise ValueError("LoRA benchmarking is only supported for vLLM"
" backend")
elif args.backend == "mii":
if args.dtype != "auto":
raise ValueError("dtype must be auto for MII backend.")
@ -452,4 +525,7 @@ if __name__ == "__main__":
if args.tokenizer != args.model:
raise ValueError("Tokenizer must be the same as the model for MII "
"backend.")
if args.enable_lora is not None:
raise ValueError("LoRA benchmarking is only supported for vLLM"
" backend")
main(args)