213 lines
6.6 KiB
Python
213 lines
6.6 KiB
Python
import argparse
|
|
import os
|
|
import pickle
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
|
|
|
|
SYSTEMS = [
|
|
'orca-constant',
|
|
'orca-power2',
|
|
'orca-oracle',
|
|
'cacheflow',
|
|
]
|
|
|
|
SYSTEM_TO_LABEL = {
|
|
'orca-constant': 'Orca (Max)',
|
|
'orca-power2': 'Orca (Pow2)',
|
|
'orca-oracle': 'Orca (Oracle)',
|
|
'cacheflow': 'KVFlow',
|
|
}
|
|
|
|
SYSTEM_TO_COLOR = {
|
|
'orca-constant': 'red',
|
|
'orca-power2': 'orange',
|
|
'orca-oracle': 'green',
|
|
'cacheflow': 'blue',
|
|
}
|
|
|
|
SYSTEM_TO_MARKER = {
|
|
'orca-constant': 'x',
|
|
'orca-power2': '^',
|
|
'orca-oracle': 's',
|
|
'cacheflow': 'o',
|
|
}
|
|
|
|
|
|
def get_results(save_dir: str) -> List[Dict[str, Any]]:
|
|
with open(os.path.join(save_dir, 'sequences.pkl'), 'rb') as f:
|
|
results = pickle.load(f)
|
|
return results
|
|
|
|
|
|
def get_request_rate(save_dir: str) -> float:
|
|
"""Get request rate from save_dir name."""
|
|
# Directory name format:
|
|
# .../req-rate-{req_rate}/seed-{seed}/duration-{duration}
|
|
save_dir = os.path.abspath(save_dir)
|
|
dir_names = save_dir.split('/')
|
|
|
|
request_rate = None
|
|
for dir_name in dir_names:
|
|
if dir_name.startswith('req-rate-'):
|
|
if request_rate is not None:
|
|
raise ValueError(f'Found multiple request rates in {save_dir}')
|
|
request_rate = float(dir_name.split('-')[-1])
|
|
if request_rate is None:
|
|
raise ValueError(f'Cannot find request rate in {save_dir}')
|
|
return request_rate
|
|
|
|
|
|
def get_model(save_dir: str) -> Tuple[str, int]:
|
|
save_dir = os.path.abspath(save_dir)
|
|
dir_names = save_dir.split('/')
|
|
|
|
model = None
|
|
for dir_name in dir_names:
|
|
if '-tp' in dir_name:
|
|
if model is not None:
|
|
raise ValueError(f'Found multiple models in {save_dir}')
|
|
model = dir_name.split('-tp')[0]
|
|
tp = int(dir_name.split('-tp')[-1])
|
|
if model is None:
|
|
raise ValueError(f'Cannot find model in {save_dir}')
|
|
return model, tp
|
|
|
|
|
|
def get_system(save_dir: str) -> str:
|
|
save_dir = os.path.abspath(save_dir)
|
|
dir_names = save_dir.split('/')
|
|
|
|
for dir_name in dir_names:
|
|
if dir_name.startswith('orca-'):
|
|
return dir_name
|
|
if dir_name == 'cacheflow':
|
|
return dir_name
|
|
raise ValueError(f'Cannot find system in {save_dir}')
|
|
|
|
|
|
def get_sampling(save_dir: str) -> str:
|
|
save_dir = os.path.abspath(save_dir)
|
|
dir_names = save_dir.split('/')
|
|
|
|
for dir_name in dir_names:
|
|
if dir_name.startswith('n'):
|
|
if dir_name.endswith('-beam'):
|
|
return dir_name
|
|
if dir_name[1:].isdigit():
|
|
return dir_name
|
|
raise ValueError(f'Cannot find sampling method in {save_dir}')
|
|
|
|
|
|
def plot_normalized_latency(
|
|
exp_dir: str,
|
|
duration: int,
|
|
seed: int,
|
|
warmup: int,
|
|
xlim: Optional[float],
|
|
ylim: Optional[float],
|
|
log_scale: bool,
|
|
format: str,
|
|
) -> None:
|
|
# Get leaf directories.
|
|
save_dirs = []
|
|
for root, dirs, files in os.walk(exp_dir):
|
|
if dirs:
|
|
continue
|
|
if 'sequences.pkl' not in files:
|
|
continue
|
|
if f'seed{seed}' not in root:
|
|
continue
|
|
if f'duration-{duration}' not in root:
|
|
continue
|
|
save_dirs.append(root)
|
|
|
|
# Plot normalized latency.
|
|
perf_per_system: Dict[str, Tuple[List[float], List[float]]] = {}
|
|
for save_dir in save_dirs:
|
|
per_seq_norm_latencies = []
|
|
results = get_results(save_dir)
|
|
for seq in results:
|
|
arrival_time = seq['arrival_time']
|
|
finish_time = seq['finish_time']
|
|
output_len = seq['output_len']
|
|
if arrival_time < warmup:
|
|
continue
|
|
latency = finish_time - arrival_time
|
|
norm_latency = latency / output_len
|
|
per_seq_norm_latencies.append(norm_latency)
|
|
|
|
request_rate = get_request_rate(save_dir)
|
|
normalized_latency = np.mean(per_seq_norm_latencies)
|
|
system_name = get_system(save_dir)
|
|
if system_name not in perf_per_system:
|
|
perf_per_system[system_name] = ([], [])
|
|
perf_per_system[system_name][0].append(request_rate)
|
|
perf_per_system[system_name][1].append(normalized_latency)
|
|
|
|
print('#seqs', len(per_seq_norm_latencies))
|
|
print(f'{save_dir}: {normalized_latency:.3f} s')
|
|
|
|
|
|
# Plot normalized latency.
|
|
plt.figure(figsize=(6, 4))
|
|
for system_name in reversed(SYSTEMS):
|
|
if system_name not in perf_per_system:
|
|
continue
|
|
# Sort by request rate.
|
|
request_rates, normalized_latencies = perf_per_system[system_name]
|
|
request_rates, normalized_latencies = zip(*sorted(zip(request_rates, normalized_latencies)))
|
|
label = SYSTEM_TO_LABEL[system_name]
|
|
color = SYSTEM_TO_COLOR[system_name]
|
|
marker = SYSTEM_TO_MARKER[system_name]
|
|
plt.plot(request_rates, normalized_latencies, label=label, color=color, marker=marker)
|
|
|
|
# plt.legend()
|
|
plt.xlabel('Request rate (req/s)', fontsize=12)
|
|
plt.ylabel('Normalized latency (s/token)', fontsize=12)
|
|
|
|
if log_scale:
|
|
plt.yscale('log')
|
|
if xlim is not None:
|
|
plt.xlim(left=0, right=xlim)
|
|
if ylim is not None:
|
|
if log_scale:
|
|
plt.ylim(top=ylim)
|
|
else:
|
|
plt.ylim(bottom=0, top=ylim)
|
|
|
|
handles, labels = plt.gca().get_legend_handles_labels()
|
|
handles = reversed(handles)
|
|
labels = reversed(labels)
|
|
|
|
plt.legend(
|
|
handles, labels,
|
|
ncol=4, fontsize=12, loc='upper center', bbox_to_anchor=(0.5, 1.15),
|
|
columnspacing=0.5, handletextpad=0.5, handlelength=1.5, frameon=False, borderpad=0)
|
|
|
|
# Save figure.
|
|
model, tp = get_model(exp_dir)
|
|
sampling = get_sampling(exp_dir)
|
|
figname = f'{model}-tp{tp}-{sampling}.{format}'
|
|
os.makedirs('./figures', exist_ok=True)
|
|
plt.savefig(os.path.join('figures', figname), bbox_inches='tight')
|
|
print(f'Saved figure to ./figures/{figname}')
|
|
|
|
|
|
if __name__ == '__main__':
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('exp_dir', type=str)
|
|
parser.add_argument('--duration', type=int, required=True)
|
|
parser.add_argument('--seed', type=int, default=0)
|
|
parser.add_argument('--warmup', type=int, default=60)
|
|
parser.add_argument('--xlim', type=float, required=False, default=None)
|
|
parser.add_argument('--ylim', type=float, required=False, default=None)
|
|
parser.add_argument('--log', action='store_true')
|
|
parser.add_argument('--format', choices=['png', 'pdf'], default='png')
|
|
args = parser.parse_args()
|
|
|
|
plot_normalized_latency(
|
|
args.exp_dir, args.duration, args.seed, args.warmup, args.xlim, args.ylim, args.log, args.format)
|