"""
Benchmark the efficiency of prefix caching .
This script allows you to benchmark the performance of
a model with and without prefix caching using either fixed prompts
or prompts sampled from the ShareGPT dataset .
Fixed example usage :
python benchmark_prefix_caching . py \
- - model meta - llama / Llama - 2 - 7 b - chat - hf \
- - enable - prefix - caching \
- - num - prompts 1 \
- - repeat - count 100
ShareGPT example usage :
# This command samples 20 prompts with input lengths
# between 128 and 256 tokens from the ShareGPT dataset,
# then replicates each prompt 5 times.
python benchmark_prefix_caching . py \
- - model meta - llama / Llama - 2 - 7 b - chat - hf \
- - dataset - path / path / to / ShareGPT_V3_unfiltered_cleaned_split . json \
- - enable - prefix - caching \
- - num - prompts 20 \
- - repeat - count 5 \
- - input - length - range 128 : 256
"""
import dataclasses
import json
import random
import time
from typing import List, Optional, Tuple

from transformers import PreTrainedTokenizerBase

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

try:
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
    from backend_request_func import get_tokenizer
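

# A ready-made long prompt (a markdown table plus a question). Repeating a
# prompt like this one is the simplest way to exercise prefix caching: every
# repeat shares the entire prompt as a cacheable prefix.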
PROMPT = " You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table. \n # Table \n |Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes| \n |----|----|----|----|----|----|----|----| \n |J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan|| \n |J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy|| \n |J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan|| \n |J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer|| \n |F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan|| \n |F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj|| \n |F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj|| \n |F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan|| \n |F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam|| \n |F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar|| \n |M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan|| \n |M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar|| \n |M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy|| \n |M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan|| \n |M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan|| \n \n # Question \n What ' s the content in the (1,1) cells \n " # noqa: E501


def test_prefix(llm=None, sampling_params=None, prompts=None):
    start_time = time.time()

    llm.generate(prompts, sampling_params=sampling_params)

    end_time = time.time()
    print(f"Cost time: {end_time - start_time:.2f} s")


@dataclasses.dataclass
class Request:
    prompt: str
    prompt_len: int
    output_len: int


def sample_tokens(tokenizer: PreTrainedTokenizerBase,
                  length: int) -> List[int]:
    vocab = tokenizer.get_vocab()
    # Remove the special tokens (compare by id, since get_vocab maps
    # token text to token id).
    vocab = {
        k: v
        for k, v in vocab.items() if v not in tokenizer.all_special_ids
    }
    return random.choices(list(vocab.values()), k=length)
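

# Note: sample_tokens returns a list of token ids; callers concatenate these
# lists and decode them back into a text prompt for the engine.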


def sample_requests_from_dataset(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    input_length_range: Tuple[int, int],
    fixed_output_len: Optional[int],
) -> List[Request]:
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)

    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # Shuffle the dataset.
    random.shuffle(dataset)

    min_len, max_len = input_length_range
    assert min_len >= 0 and max_len >= min_len, "invalid input_length_range"

    # Filter out sequences that are too long or too short.
    filtered_requests: List[Request] = []

    for i in range(len(dataset)):
        if len(filtered_requests) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt_token_ids = tokenizer(dataset[i][0]).input_ids
        prompt = tokenizer.decode(prompt_token_ids)
        completion = dataset[i][1]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = (len(completion_token_ids)
                      if fixed_output_len is None else fixed_output_len)
        if min_len <= prompt_len <= max_len:
            filtered_requests.append(Request(prompt, prompt_len, output_len))

    return filtered_requests


def sample_requests_from_random(
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    input_length_range: Tuple[int, int],
    fixed_output_len: Optional[int],
    prefix_len: int,
) -> List[Request]:
    requests = []
    prefix_token_ids = sample_tokens(tokenizer, prefix_len)
    min_len, max_len = input_length_range

    for _ in range(num_requests):
        unique_part_token_ids = sample_tokens(
            tokenizer,
            random.randint(min_len - prefix_len, max_len - prefix_len))
        prompt_token_ids = prefix_token_ids + unique_part_token_ids
        prompt = tokenizer.decode(prompt_token_ids)
        prompt_len = len(prompt_token_ids)
        assert (min_len <= prompt_len <= max_len
                ), f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
        requests.append(Request(prompt, prompt_len, fixed_output_len))
    return requests
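

# Example: with prefix_len=64 and input_length_range=(128, 256), every prompt
# starts with the same 64 random tokens followed by a unique tail of 64 to
# 192 tokens, so only the shared prefix is reusable across requests.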


def repeat_and_sort_requests(requests: List[Request],
                             repeat_count: int,
                             sort: bool = False) -> List[str]:
    repeated_requests = requests * repeat_count
    if sort:
        repeated_requests.sort(key=lambda x: x.prompt_len)
    else:
        random.shuffle(repeated_requests)
    return [req.prompt for req in repeated_requests]
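

# Sorting by prompt length tends to place repeats of the same prompt next to
# each other (identical prompts have identical lengths), which is the most
# cache-friendly ordering; the default shuffle interleaves prompts and
# measures cache reuse under a less favorable request order.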


def main(args):
    tokenizer = get_tokenizer(args.model, trust_remote_code=True)
    input_length_range = tuple(map(int, args.input_length_range.split(':')))
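    # e.g. "128:256" parses to the inclusive bounds (128, 256).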
    random.seed(args.seed)
    if args.dataset_path is not None:
        if args.prefix_len > 0:
            raise ValueError("prefix-len is not supported when "
                             "dataset-path is provided.")
        print(f"Sampling {args.num_prompts} prompts "
              f"from {args.dataset_path}")
        filtered_requests = sample_requests_from_dataset(
            dataset_path=args.dataset_path,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            input_length_range=input_length_range,
            fixed_output_len=args.output_len,
        )
    else:
        print(f"Sampling {args.num_prompts} random prompts")
        filtered_requests = sample_requests_from_random(
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            input_length_range=input_length_range,
            fixed_output_len=args.output_len,
            prefix_len=args.prefix_len,
        )

    # Print some helpful stats of the requests.
    print(f"Sampled {len(filtered_requests)} requests.")
    prompt_lens = [req.prompt_len for req in filtered_requests]
    print(f"Average input length: {sum(prompt_lens) / len(prompt_lens)}")
    print(f"P50 input length: {sorted(prompt_lens)[len(prompt_lens) // 2]}")
    print(f"Min prompt length: {min(prompt_lens)}")
    print(f"Max prompt length: {max(prompt_lens)}")

    engine_args = EngineArgs.from_cli_args(args)
    llm = LLM(**dataclasses.asdict(engine_args))

    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
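    # temperature=0 selects greedy decoding, so repeated runs emit the same
    # tokens and timing differences reflect caching rather than sampling noise.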

    print("Testing filtered requests")
    prompts = repeat_and_sort_requests(filtered_requests,
                                       repeat_count=args.repeat_count,
                                       sort=args.sort)

    print("------start generating------")
    test_prefix(
        llm=llm,
        prompts=prompts,
        sampling_params=sampling_params,
    )


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description=
        'Benchmark the performance with or without automatic prefix caching.')
    parser.add_argument("--dataset-path",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument('--output-len', type=int, default=10)
    parser.add_argument('--num-prompts',
                        type=int,
                        required=True,
                        help="Number of prompts sampled from the dataset")
    parser.add_argument('--repeat-count',
                        type=int,
                        default=1,
                        help='Number of times to repeat each prompt')
    parser.add_argument('--sort',
                        action='store_true',
                        help='Sort prompts by input length')
    parser.add_argument('--input-length-range',
                        type=str,
                        required=True,
                        help='Range of input lengths for sampling prompts, '
                        'specified as "min:max" (e.g., "128:256").')
    parser.add_argument(
        "--prefix-len",
        type=int,
        default=0,
        help="Specifies the length of a common prefix to be "
        "added to the input prompt. The input-length-range will "
        "subtract this length when filtering prompts. Only used "
        "when dataset-path is not provided.",
    )

    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    main(args)