vllm/benchmarks/trace.py

117 lines
4.2 KiB
Python
Raw Normal View History

import pickle
import random
from typing import List, Tuple
import numpy as np
from cacheflow.sampling_params import SamplingParams
def generate_text_completion_requests(
    dataset: str,
    request_rate: float,
    duration: int,
    seed: int,
    n1: float = 0.0,
    n2: float = 0.0,
    n3: float = 0.0,
    n4: float = 0.0,
    n6: float = 0.0,
    n2_beam: float = 0.0,
    n4_beam: float = 0.0,
    n6_beam: float = 0.0,
    n8_beam: float = 0.0,
    max_seq_len: int = 2048,
    time_quantum: int = 10,
) -> List[Tuple[float, List[int], SamplingParams]]:
    """Build a timestamped trace of text-completion requests.

    Arrival counts are drawn per time quantum from a Poisson distribution,
    prompts are taken (with repetition if necessary) from a pickled dataset,
    and the shuffled requests are partitioned, in trace order, into
    sampling configurations according to the given ratios.

    Args:
        dataset: Path to a pickle file holding an iterable of
            ``(input_tokens, output_tokens)`` pairs.
        request_rate: Average number of requests per second.
        duration: Length of the trace in seconds.
        seed: Seed applied to both ``random`` and ``np.random``.
        n1, n2, n3, n4, n6: Fraction of requests using random sampling
            (temperature 1.0) with ``n`` = 1, 2, 3, 4, 6 respectively.
        n2_beam, n4_beam, n6_beam, n8_beam: Fraction of requests using beam
            search (temperature 0.0) with ``n`` = 2, 4, 6, 8 respectively.
        max_seq_len: Pairs whose input length plus output length is not
            strictly below this value are dropped.
        time_quantum: Width of one arrival slot in milliseconds.

    Returns:
        A list of ``(timestamp_in_seconds, input_tokens, sampling_params)``
        tuples, one per generated request, ordered by timestamp.

    Raises:
        ValueError: If any ratio is negative, if the ratios do not sum to 1
            (within a small floating-point tolerance), or if requests were
            generated but no dataset entry passes the length filter.
    """
    # (ratio, n, use_beam_search), in the order the trace is partitioned.
    request_mix = [
        (n1, 1, False),
        (n2, 2, False),
        (n3, 3, False),
        (n4, 4, False),
        (n6, 6, False),
        (n2_beam, 2, True),
        (n4_beam, 4, True),
        (n6_beam, 6, True),
        (n8_beam, 8, True),
    ]

    # Validate the mix up front. Exact `== 1.0` would reject valid inputs
    # such as 0.7 + 0.2 + 0.1 (== 0.9999999999999999 in binary floats).
    if any(ratio < 0 for ratio, _, _ in request_mix):
        raise ValueError('Invalid request ratio.')
    cumulative_ratios = []
    running_total = 0.0
    for ratio, _, _ in request_mix:
        running_total += ratio
        cumulative_ratios.append(running_total)
    # Written as `not (... <= tol)` so that a NaN total is rejected too.
    if not abs(running_total - 1.0) <= 1e-6:
        raise ValueError('Invalid request ratio.')

    random.seed(seed)
    np.random.seed(seed)

    # Generate timestamps for requests using Poisson distribution: one draw
    # per quantum gives the number of requests arriving in that slot.
    quantum_sec = time_quantum / 1000
    quantums_per_sec = 1000 / time_quantum
    arrival_counts = np.random.poisson(
        lam=request_rate * quantum_sec,
        size=int(duration * quantums_per_sec))
    timestamps = []
    for slot, count in enumerate(arrival_counts):
        timestamps.extend([slot * quantum_sec] * int(count))
    num_requests = len(timestamps)

    # Load the dataset.
    # NOTE: pickle.load can execute arbitrary code; only pass trusted files.
    with open(dataset, 'rb') as f:
        pairs = pickle.load(f)

    # Filter out too long sequences. Output tokens are not needed for the
    # benchmark, only their count.
    filtered = [
        (input_tokens, len(output_tokens))
        for input_tokens, output_tokens in pairs
        if len(input_tokens) + len(output_tokens) < max_seq_len
    ]
    if num_requests > 0 and not filtered:
        # Repeating an empty list can never reach num_requests entries.
        raise ValueError(
            'No dataset entries are shorter than max_seq_len.')

    # Repeat the filtered data until there is one entry per timestamp.
    if filtered:
        repeats = -(-num_requests // len(filtered))  # ceiling division
        data = (filtered * repeats)[:num_requests]
    else:
        data = []
    random.shuffle(data)

    # Generate requests based on the sampling parameter ratio. Request i
    # goes to the first configuration whose cumulative share of the trace
    # exceeds i; rounding leftovers fall into the last non-empty one.
    thresholds = [share * num_requests for share in cumulative_ratios]
    last_bucket = max(
        i for i, (ratio, _, _) in enumerate(request_mix) if ratio > 0)
    requests = []
    bucket = 0
    for index, (timestamp, (input_tokens, output_len)) in enumerate(
            zip(timestamps, data)):
        while bucket < last_bucket and index >= thresholds[bucket]:
            bucket += 1
        _, num_seqs, use_beam_search = request_mix[bucket]
        sampling_params = SamplingParams(
            n=num_seqs,
            max_num_steps=output_len,
            temperature=0.0 if use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=use_beam_search,
            # A fresh set per request avoids sharing one mutable object.
            stop_token_ids=set(),
            num_logprobs=0,
            context_window_size=None,
        )
        requests.append((timestamp, input_tokens, sampling_params))
    return requests