"""
Benchmark the efficiency of prefix caching .
This script allows you to benchmark the performance of
a model with and without prefix caching using either fixed prompts
or prompts sampled from the ShareGPT dataset .
Fixed example usage :
python benchmark_prefix_caching . py \
- - model meta - llama / Llama - 2 - 7 b - chat - hf \
- - enable - prefix - caching \
- - num - prompts 1 \
- - repeat - count 100
ShareGPT example usage :
# This command samples 20 prompts with input lengths
# between 128 and 256 tokens from the ShareGPT dataset,
# then replicates each prompt 5 times.
python benchmark_prefix_caching . py \
- - model meta - llama / Llama - 2 - 7 b - chat - hf \
- - dataset - path / path / to / ShareGPT_V3_unfiltered_cleaned_split . json \
- - enable - prefix - caching \
- - num - prompts 20 \
- - repeat - count 5 \
- - input - length - range 128 : 256
"""
import json
import random
import time
from typing import List, Optional, Tuple
from transformers import PreTrainedTokenizerBase
from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser
try:
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
    from backend_request_func import get_tokenizer
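# Fixed prompt used when --dataset-path is not given: a long table-QA prompt
# that every request shares, so its prefix can be cached across repeats.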
PROMPT = " You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table. \n # Table \n |Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes| \n |----|----|----|----|----|----|----|----| \n |J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan|| \n |J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy|| \n |J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan|| \n |J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer|| \n |F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan|| \n |F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj|| \n |F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj|| \n |F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan|| \n |F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam|| \n |F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar|| \n |M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan|| \n |M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar|| \n |M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy|| \n |M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan|| \n |M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan|| \n \n # Question \n What ' s the content in the (1,1) cells \n " # noqa: E501


def test_prefix(llm=None, sampling_params=None, prompts=None):
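    """Run a single batched `llm.generate` call over `prompts` and print the
    elapsed wall-clock time in seconds."""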
    start_time = time.time()
    llm.generate(prompts, sampling_params=sampling_params)
    end_time = time.time()
    print(f"cost time {end_time - start_time}")


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    input_length_range: Tuple[int, int],
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
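    """Sample `num_requests` prompts from a ShareGPT-style JSON dataset.

    Returns a list of (prompt, prompt_len, output_len) tuples whose tokenized
    prompt lengths fall within `input_length_range`.
    """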
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out conversations with fewer than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # Shuffle the dataset.
    random.shuffle(dataset)

    min_len, max_len = input_length_range

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = dataset[i][1]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = (len(completion_token_ids)
                      if fixed_output_len is None else fixed_output_len)
        if prompt_len < 4 or output_len < 4:
            # Prune sequences that are too short.
            continue
        if min_len <= prompt_len <= max_len:
            filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset


def repeat_and_sort_requests(requests: List[Tuple[str, int, int]],
                             repeat_count: int,
                             sort: bool = False) -> List[str]:
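    """Repeat each request `repeat_count` times, then either sort the result
    by prompt length (so repeats of a prompt are adjacent) or shuffle it."""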
    repeated_requests = requests * repeat_count
    if sort:
        repeated_requests.sort(key=lambda x: x[1])
    else:
        random.shuffle(repeated_requests)
    return [req[0] for req in repeated_requests]


def main(args):
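    """Build the prompt list, then time two generation passes over it:
    a warm-up pass and a measured pass."""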
    tokenizer = get_tokenizer(args.model, trust_remote_code=True)
    input_length_range = tuple(map(int, args.input_length_range.split(':')))
    random.seed(args.seed)
    if args.dataset_path is not None:
        print(f"Start to sample {args.num_prompts} prompts "
              f"from {args.dataset_path}")
        filtered_datasets = sample_requests(
            dataset_path=args.dataset_path,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            input_length_range=input_length_range,
            fixed_output_len=args.output_len,
        )
    else:
        prompt_len = len(tokenizer(PROMPT).input_ids)
        filtered_datasets = [(PROMPT, prompt_len, args.output_len)
                             ] * args.num_prompts
    llm = LLM(model=args.model,
              tokenizer_mode='auto',
              trust_remote_code=True,
              enforce_eager=True,
              tensor_parallel_size=args.tensor_parallel_size,
              enable_prefix_caching=args.enable_prefix_caching)
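    # temperature=0 means greedy decoding, so both timing passes perform the
    # same generation work.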
    sampling_params = SamplingParams(temperature=0,
                                     max_tokens=args.output_len)
print ( " Testing filtered datasets " )
prompts = repeat_and_sort_requests ( filtered_datasets ,
repeat_count = args . repeat_count ,
sort = args . sort )
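    # With prefix caching enabled, the warm-up pass below also populates the
    # prefix cache, so the measured pass mostly exercises cache hits.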
print ( " ------warm up------ " )
test_prefix (
llm = llm ,
2024-05-02 02:20:32 +08:00
prompts = prompts ,
2024-03-03 14:37:18 -08:00
sampling_params = sampling_params ,
)
print ( " ------start generating------ " )
test_prefix (
llm = llm ,
prompts = prompts ,
sampling_params = sampling_params ,
)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description=
        'Benchmark the performance with or without automatic prefix caching.')
    parser.add_argument('--model',
                        type=str,
                        default='baichuan-inc/Baichuan2-13B-Chat')
parser . add_argument ( " --dataset-path " ,
type = str ,
default = None ,
help = " Path to the dataset. " )
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--output-len', type=int, default=10)
    parser.add_argument('--enable-prefix-caching',
                        action='store_true',
                        help='enable prefix caching')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=1,
                        help="Number of prompts to sample from the dataset")
    parser.add_argument('--repeat-count',
                        type=int,
                        default=100,
                        help='Number of times to repeat each prompt')
    parser.add_argument('--sort',
                        action='store_true',
                        help='Sort prompts by input length')
    parser.add_argument('--input-length-range',
                        type=str,
                        default='128:256',
                        help='Range of input lengths for sampling prompts, '
                        'specified as "min:max" (e.g., "128:256").')
parser . add_argument ( " --seed " ,
type = int ,
default = 0 ,
2024-10-04 14:58:57 -07:00
help = ' Random seed for reproducibility ' )
    args = parser.parse_args()
    main(args)