Collect system stats in scheduler & Add scripts for experiments (#30)

Author: Woosuk Kwon, 2023-04-12 15:03:49 -07:00 (committed by GitHub)
parent e3cec88aa5
commit 84eee24e20
12 changed files with 830 additions and 3 deletions

.gitignore (4 changes)

@@ -4,3 +4,7 @@
*.eggs/
*.so
build/
*.pkl
*.png
**/log.txt

@@ -37,6 +37,7 @@ def main(args: argparse.Namespace):
seed=args.seed,
swap_space=args.swap_space,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,

@@ -0,0 +1,289 @@
import argparse
import logging
import os
import pickle
import time
from typing import List
from tqdm import tqdm
from transformers import AutoConfig
from benchmark.trace import generate_text_completion_requests
from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
initialize_ray_cluster)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory
logger = logging.getLogger(__name__)
def main(args: argparse.Namespace):
assert args.pipeline_parallel_size == 1, (
'Pipeline parallelism is not supported yet.')
(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_ray_cluster(
address='local',
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))
# Create a server.
server = Server(
model=args.model,
model_path=args.model_path,
use_dummy_weights=args.use_dummy_weights,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
collect_stats=True,
do_memory_analysis=args.do_memory_analysis,
)
# Create a frontend.
frontend = SimpleFrontend(
model_name=args.model,
block_size=args.block_size,
)
# Generate requests.
requests = generate_text_completion_requests(
args.dataset,
args.request_rate,
args.duration,
args.seed,
args.n1,
args.n2,
args.n3,
args.n4,
args.n6,
args.n2_beam,
args.n4_beam,
args.n6_beam,
args.n8_beam,
)
# Warm up.
logger.info('Warming up.')
num_warmup_requests = 8
warmup_input_len = 8
warmup_output_len = 32
warmup_sampling_params = SamplingParams(
n=1,
temperature=1.0,
top_p=0.99,
max_num_steps=warmup_output_len,
use_beam_search=False,
stop_token_ids=set(),
num_logprobs=0,
context_window_size=None,
)
for _ in range(num_warmup_requests):
frontend._add_query([0] * warmup_input_len, warmup_sampling_params)
server.add_sequence_groups(frontend.get_inputs())
while True:
server.step()
if not server.has_unfinished_requests():
break
# Start benchmarking.
logger.info('Start benchmarking.')
# Initialize tqdm.
pbar = tqdm(total=len(requests), desc='Finished requests')
finished = []
server.scheduler.reset_stats()
start_time = time.time()
while True:
now = time.time()
if args.timeout is not None and now - start_time > args.timeout:
logger.info('Timeout. Stop benchmarking.')
break
while requests:
if requests[0][0] <= now - start_time:
request_time, input_tokens, sampling_params = requests.pop(0)
frontend._add_query(
input_tokens, sampling_params, arrival_time=start_time + request_time)
else:
break
server.add_sequence_groups(frontend.get_inputs())
updated_seq_groups = server.step()
now = time.time()
for seq_group in updated_seq_groups:
if not seq_group.is_finished():
continue
arrival_time = seq_group.arrival_time
finish_time = now
for seq in seq_group.get_seqs():
seq_len = seq.get_len()
output_len = seq_len - seq.prompt_len
finished.append({
'group_id': seq_group.group_id,
'seq_id': seq.seq_id,
'arrival_time': arrival_time,
'finish_time': finish_time,
'prompt_len': seq.prompt_len,
'output_len': output_len,
})
pbar.update(1)
if not (requests or server.has_unfinished_requests()):
break
pbar.close()
logger.info('Finish benchmarking. Saving stats.')
server.scheduler.save_stats(args.output_dir)
with open(os.path.join(args.output_dir, 'sequences.pkl'), 'wb') as f:
pickle.dump(finished, f)
logger.info('Done.')
def get_model_name(model: str) -> str:
OPT_MODELS = [
'opt-125m',
'opt-350m',
'opt-1.3b',
'opt-2.7b',
'opt-6.7b',
'opt-13b',
'opt-30b',
'opt-66b',
'opt-175b',
]
for opt_model in OPT_MODELS:
if opt_model in model:
return opt_model
config = AutoConfig.from_pretrained(model)
assert config.model_type == 'llama'
hidden_size = config.hidden_size
if hidden_size == 4096:
return 'llama-7b'
elif hidden_size == 5120:
return 'llama-13b'
elif hidden_size == 6656:
return 'llama-30b'
elif hidden_size == 8192:
return 'llama-65b'
else:
raise ValueError(f'Unknown model: {model}')
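# Example: 'facebook/opt-13b' maps to 'opt-13b'; a LLaMA checkpoint whose
# config reports hidden_size == 5120 maps to 'llama-13b'.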
def get_dataset_name(dataset: str) -> str:
if 'sharegpt' in dataset.lower():
return 'sharegpt'
elif 'alpaca' in dataset.lower():
return 'alpaca'
else:
raise ValueError(f'Unknown dataset: {dataset}')
def get_sampling_dir_name(
n1: float,
n2: float,
n3: float,
n4: float,
n6: float,
n2_beam: float,
n4_beam: float,
n6_beam: float,
n8_beam: float,
) -> str:
method = ''
if n1 > 0.0:
method = 'n1' if n1 == 1.0 else method + f'n1-{n1}-'
if n2 > 0.0:
method = 'n2' if n2 == 1.0 else method + f'n2-{n2}-'
if n3 > 0.0:
method = 'n3' if n3 == 1.0 else method + f'n3-{n3}-'
if n4 > 0.0:
method = 'n4' if n4 == 1.0 else method + f'n4-{n4}-'
if n6 > 0.0:
method = 'n6' if n6 == 1.0 else method + f'n6-{n6}-'
if n2_beam > 0.0:
method = 'n2-beam' if n2_beam == 1.0 else method + f'n2-beam-{n2_beam}-'
if n4_beam > 0.0:
method = 'n4-beam' if n4_beam == 1.0 else method + f'n4-beam-{n4_beam}-'
if n6_beam > 0.0:
method = 'n6-beam' if n6_beam == 1.0 else method + f'n6-beam-{n6_beam}-'
if n8_beam > 0.0:
method = 'n8-beam' if n8_beam == 1.0 else method + f'n8-beam-{n8_beam}-'
return method[:-1] if method.endswith('-') else method
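# Example: n1=1.0 alone yields 'n1'; a 30/70 mix of n=1 sampling and n=4 beam
# search (n1=0.3, n4_beam=0.7) yields 'n1-0.3-n4-beam-0.7'.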
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CacheFlow text completion benchmark.')
parser = add_server_arguments(parser)
parser.add_argument('--output-dir', type=str, help='path to output directory', default=None)
parser.add_argument('--dataset', type=str, help='path to dataset', required=True)
parser.add_argument('--request-rate', type=float, help='reqs/sec', required=True)
parser.add_argument('--duration', type=int, help='duration in seconds', required=True)
parser.add_argument('--do-memory-analysis', action='store_true',
help='do memory analysis (This will lower the throughput. Use this only for analysis.)')
    parser.add_argument('--timeout', type=int, help='timeout in seconds', default=None)
parser.add_argument('--n1', type=float, help='ratio of requests with n=1', default=0.0)
parser.add_argument('--n2', type=float, help='ratio of requests with n=2', default=0.0)
parser.add_argument('--n3', type=float, help='ratio of requests with n=3', default=0.0)
parser.add_argument('--n4', type=float, help='ratio of requests with n=4', default=0.0)
parser.add_argument('--n6', type=float, help='ratio of requests with n=6', default=0.0)
parser.add_argument('--n2-beam', type=float, help='ratio of requests with n=2 & beam search', default=0.0)
parser.add_argument('--n4-beam', type=float, help='ratio of requests with n=4 & beam search', default=0.0)
parser.add_argument('--n6-beam', type=float, help='ratio of requests with n=6 & beam search', default=0.0)
parser.add_argument('--n8-beam', type=float, help='ratio of requests with n=8 & beam search', default=0.0)
args = parser.parse_args()
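    # Note: the check below uses exact floating-point equality, so the ratio
    # flags must sum to exactly 1.0; some decimal mixes can fail due to
    # accumulated rounding.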
if args.n1 + args.n2 + args.n3 + args.n4 + args.n6 + args.n2_beam + args.n4_beam + args.n6_beam + args.n8_beam != 1.0:
raise ValueError('The ratios of requests must sum to 1.')
model_name = get_model_name(args.model)
dataset_name = get_dataset_name(args.dataset)
    if 'opt' in model_name:
        if 'opt' not in args.dataset.lower():
            raise ValueError('OPT models can only be used with OPT datasets.')
    elif 'llama' in model_name:
        if 'llama' not in args.dataset.lower():
            raise ValueError('Llama models can only be used with Llama datasets.')
sample_dir = get_sampling_dir_name(
args.n1, args.n2, args.n3, args.n4, args.n6, args.n2_beam, args.n4_beam, args.n6_beam, args.n8_beam)
if args.output_dir is None:
args.output_dir = os.path.join(
'../exp',
dataset_name,
f'{model_name}-tp{args.tensor_parallel_size}',
sample_dir,
'cacheflow',
f'req-rate-{args.request_rate}',
f'seed{args.seed}',
f'duration-{args.duration}',
)
os.makedirs(args.output_dir, exist_ok=True)
# Set up logging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
handlers=[
logging.StreamHandler(),
logging.FileHandler(os.path.join(args.output_dir, 'log.txt')),
],
)
logger.info(args)
main(args)
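A hypothetical invocation of this benchmark script (the script name, model, and dataset path are illustrative, not taken from this commit):

    python benchmark_text_completion.py --model facebook/opt-13b --dataset sharegpt.pkl \
        --request-rate 1.0 --duration 600 --n1 1.0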

benchmark/trace.py (new file, 116 lines)

@@ -0,0 +1,116 @@
import pickle
import random
from typing import List, Tuple
import numpy as np
from cacheflow.sampling_params import SamplingParams
def generate_text_completion_requests(
dataset: str,
request_rate: float,
duration: int,
seed: int,
n1: float = 0.0,
n2: float = 0.0,
n3: float = 0.0,
n4: float = 0.0,
n6: float = 0.0,
n2_beam: float = 0.0,
n4_beam: float = 0.0,
n6_beam: float = 0.0,
n8_beam: float = 0.0,
max_seq_len: int = 2048,
time_quantum: int = 10,
) -> List[Tuple[float, List[int], SamplingParams]]:
random.seed(seed)
np.random.seed(seed)
# Generate timestamps for requests using Poisson distribution.
lam = request_rate * (time_quantum / 1000)
quantums_per_sec = 1000 / time_quantum
arrival_times = np.random.poisson(
lam=lam, size=int(duration * quantums_per_sec))
timestamps = []
for i, n in enumerate(arrival_times):
timestamps += [i * (time_quantum / 1000)] * n
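    # Example: with request_rate=2.0 req/s and time_quantum=10 ms, lam is
    # 0.02 expected arrivals per quantum; a 600 s run draws 60,000 Poisson
    # samples whose counts sum to ~1,200 request timestamps.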
# Load and shuffle the dataset.
num_requests = len(timestamps)
with open(dataset, 'rb') as f:
data = pickle.load(f)
filtered = []
for pair in data:
input_tokens, output_tokens = pair
input_len = len(input_tokens)
output_len = len(output_tokens)
# Filter out too long sequences.
if input_len + output_len < max_seq_len:
# Output tokens are not needed for the benchmark.
filtered.append((input_tokens, output_len))
data = []
while len(data) < num_requests:
data += filtered
data = data[:num_requests]
# Shuffle the data.
assert len(data) == len(timestamps)
random.shuffle(data)
random_sampling_params_dict = {
'temperature': 1.0,
'top_p': 1.0,
'use_beam_search': False,
'stop_token_ids': set(),
'num_logprobs': 0,
'context_window_size': None,
}
beam_search_params_dict = {
'temperature': 0.0,
'top_p': 1.0,
'use_beam_search': True,
'stop_token_ids': set(),
'num_logprobs': 0,
'context_window_size': None,
}
# Generate requests based on the sampling parameter ratio.
requests = []
assert n1 + n2 + n3 + n4 + n6 + n2_beam + n4_beam + n6_beam + n8_beam == 1.0
cum_sum = 0
for timestamp, pair in zip(timestamps, data):
input_tokens, output_len = pair
if cum_sum < n1 * num_requests:
sampling_params = SamplingParams(
n=1, max_num_steps=output_len, **random_sampling_params_dict)
elif cum_sum < (n1 + n2) * num_requests:
sampling_params = SamplingParams(
n=2, max_num_steps=output_len, **random_sampling_params_dict)
elif cum_sum < (n1 + n2 + n3) * num_requests:
sampling_params = SamplingParams(
n=3, max_num_steps=output_len, **random_sampling_params_dict)
elif cum_sum < (n1 + n2 + n3 + n4) * num_requests:
sampling_params = SamplingParams(
n=4, max_num_steps=output_len, **random_sampling_params_dict)
elif cum_sum < (n1 + n2 + n3 + n4 + n6) * num_requests:
sampling_params = SamplingParams(
n=6, max_num_steps=output_len, **random_sampling_params_dict)
elif cum_sum < (n1 + n2 + n3 + n4 + n6 + n2_beam) * num_requests:
sampling_params = SamplingParams(
n=2, max_num_steps=output_len, **beam_search_params_dict)
elif cum_sum < (n1 + n2 + n3 + n4 + n6 + n2_beam + n4_beam) * num_requests:
sampling_params = SamplingParams(
n=4, max_num_steps=output_len, **beam_search_params_dict)
elif cum_sum < (n1 + n2 + n3 + n4 + n6 + n2_beam + n4_beam + n6_beam) * num_requests:
sampling_params = SamplingParams(
n=6, max_num_steps=output_len, **beam_search_params_dict)
elif cum_sum < (n1 + n2 + n3 + n4 + n6 + n2_beam + n4_beam + n6_beam + n8_beam) * num_requests:
sampling_params = SamplingParams(
n=8, max_num_steps=output_len, **beam_search_params_dict)
else:
raise ValueError('Invalid request ratio.')
cum_sum += 1
requests.append((timestamp, input_tokens, sampling_params))
return requests
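A minimal sketch of consuming the returned trace (dataset path and rate are illustrative; each element is an (arrival_offset_seconds, prompt_token_ids, SamplingParams) tuple):

    requests = generate_text_completion_requests(
        dataset='sharegpt.pkl', request_rate=1.0, duration=600, seed=0, n1=1.0)
    for arrival, prompt_tokens, sampling_params in requests[:3]:
        print(f'{arrival:.2f}s: prompt of {len(prompt_tokens)} tokens')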

@@ -241,3 +241,9 @@ class BlockSpaceManager:
def get_block_table(self, seq: Sequence) -> List[int]:
block_table = self.block_tables[seq.seq_id]
return [block.block_number for block in block_table]
def get_num_free_gpu_blocks(self) -> int:
return self.gpu_allocator.get_num_free_blocks()
def get_num_free_cpu_blocks(self) -> int:
return self.cpu_allocator.get_num_free_blocks()

@@ -1,6 +1,8 @@
import enum
import os
import pickle
import time
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.master.policy import PolicyFactory
@@ -34,12 +36,18 @@ class Scheduler:
num_gpu_blocks: int,
num_cpu_blocks: int,
max_num_batched_tokens: int,
max_num_sequences: int,
collect_stats: bool,
do_memory_analysis: bool = False,
) -> None:
self.controllers = controllers
self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks
self.num_cpu_blocks = num_cpu_blocks
self.max_num_batched_tokens = max_num_batched_tokens
self.max_num_sequences = max_num_sequences
self.collect_stats = collect_stats
self.do_memory_analysis = do_memory_analysis
# Instantiate the scheduling policy.
self.policy = PolicyFactory.get_policy(policy_name='fcfs')
@@ -61,6 +69,9 @@ class Scheduler:
# Sequence groups in the SWAPPED state.
self.swapped: List[SequenceGroup] = []
# Performance-related statistics.
self.stats = Stats(num_gpu_blocks, num_cpu_blocks)
def add_sequence_groups(
self,
seq_groups: List[Tuple[SequenceGroup, SamplingParams]],
@@ -123,6 +134,12 @@
if not self.block_manager.can_swap_in(seq_group):
break
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
if len(self.running) + num_seqs > self.max_num_sequences:
break
seq_group = self.swapped.pop(0)
self._swap_in(seq_group, blocks_to_swap_in)
self._append(seq_group, blocks_to_copy)
@@ -156,12 +173,68 @@
> self.max_num_batched_tokens):
break
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
if len(self.running) + num_seqs > self.max_num_sequences:
break
seq_group = self.waiting.pop(0)
self._allocate(seq_group)
self.running.append(seq_group)
num_batched_tokens += num_prompt_tokens
prompt_group_ids.append(seq_group.group_id)
if self.collect_stats:
if self.running or blocks_to_swap_in or blocks_to_swap_out:
self.stats.timestamps.append(now - self.stats.start_time)
self.stats.input_lens.append(num_batched_tokens)
self.stats.swap_out_lens.append(len(blocks_to_swap_out) * self.block_size)
self.stats.swap_in_lens.append(len(blocks_to_swap_in) * self.block_size)
self.stats.num_preemption.append(len(preempted))
self.stats.num_swapped.append(len(self.swapped))
self.stats.num_running.append(len(self.running))
self.stats.num_waiting.append(len(self.waiting))
num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
num_used_gpu_blocks = self.num_gpu_blocks - num_free_gpu_blocks
self.stats.gpu_cache_usage.append(num_used_gpu_blocks / self.num_gpu_blocks)
num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
num_used_cpu_blocks = self.num_cpu_blocks - num_free_cpu_blocks
self.stats.cpu_cache_usage.append(num_used_cpu_blocks / self.num_cpu_blocks)
if self.do_memory_analysis:
block_tables = self.block_manager.block_tables
num_logical_blocks = 0
num_logical_tokens = 0
num_physical_blocks = 0
num_physical_tokens = 0
physical_block_numbers = set()
num_reserved_tokens = 0
for seq_group in self.running:
group_id = seq_group.group_id
sampling_params = self.sampling_params[group_id]
max_num_steps = sampling_params.max_num_steps
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
num_logical_blocks += len(seq.logical_token_blocks)
num_logical_tokens += seq.get_len()
seq_id = seq.seq_id
block_table = block_tables[seq_id]
for i, block in enumerate(block_table):
if block.block_number in physical_block_numbers:
continue
physical_block_numbers.add(block.block_number)
num_physical_blocks += 1
num_physical_tokens += seq.logical_token_blocks[i].num_tokens
assert num_physical_blocks == num_used_gpu_blocks
self.stats.num_logical_blocks.append(num_logical_blocks)
self.stats.num_logical_tokens.append(num_logical_tokens)
self.stats.num_physical_blocks.append(num_physical_blocks)
self.stats.num_physical_tokens.append(num_physical_tokens)
self.stats.num_reserved_tokens.append(num_reserved_tokens)
return (blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
@@ -381,3 +454,75 @@ class Scheduler:
blocks_to_swap_out.update(mapping)
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq.status = SequenceStatus.SWAPPED
def reset_stats(self) -> None:
self.stats.reset(self.num_gpu_blocks, self.num_cpu_blocks)
def save_stats(
self,
output_dir: str,
) -> None:
assert self.collect_stats, 'Statistics collection is disabled.'
self.stats.save(output_dir)
class Stats:
def __init__(
self,
num_gpu_blocks: int,
num_cpu_blocks: int,
) -> None:
self.start_time: float = time.time()
self.num_gpu_blocks = num_gpu_blocks
self.num_cpu_blocks = num_cpu_blocks
self.timestamps: List[float] = []
self.input_lens: List[int] = []
self.swap_out_lens: List[int] = []
self.swap_in_lens: List[int] = []
self.num_preemption: List[int] = []
self.num_waiting: List[int] = []
self.num_running: List[int] = []
self.num_swapped: List[int] = []
self.gpu_cache_usage: List[float] = []
self.cpu_cache_usage: List[float] = []
self.num_logical_blocks: List[int] = []
self.num_logical_tokens: List[int] = []
self.num_physical_blocks: List[int] = []
self.num_physical_tokens: List[int] = []
self.num_reserved_tokens: List[int] = []
def reset(
self,
num_gpu_blocks: int,
num_cpu_blocks: int,
) -> None:
self.__init__(num_gpu_blocks, num_cpu_blocks)
def to_dict(self) -> Dict[str, Any]:
return {
'start_time': self.start_time,
'num_gpu_blocks': self.num_gpu_blocks,
'num_cpu_blocks': self.num_cpu_blocks,
'timestamps': self.timestamps,
'input_lens': self.input_lens,
'swap_out_lens': self.swap_out_lens,
'swap_in_lens': self.swap_in_lens,
'num_preemption': self.num_preemption,
'num_waiting': self.num_waiting,
'num_running': self.num_running,
'num_swapped': self.num_swapped,
'gpu_cache_usage': self.gpu_cache_usage,
'cpu_cache_usage': self.cpu_cache_usage,
'num_logical_blocks': self.num_logical_blocks,
'num_logical_tokens': self.num_logical_tokens,
'num_physical_blocks': self.num_physical_blocks,
'num_physical_tokens': self.num_physical_tokens,
'num_reserved_tokens': self.num_reserved_tokens,
}
def save(self, output_dir: str) -> None:
with open(os.path.join(output_dir, 'stats.pkl'), 'wb') as f:
pickle.dump(self.to_dict(), f)
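The dictionary pickled here is the stats.pkl file that plot/plot_stats.py (below) loads and plots.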

@@ -24,12 +24,15 @@ class Server:
seed: int,
swap_space: int,
max_num_batched_tokens: int,
max_num_sequences: int,
num_nodes: int,
num_devices_per_node: int,
distributed_init_method: str,
all_stage_devices: List[List[DeviceID]],
gpu_memory: int,
cpu_memory: int,
collect_stats: bool = False,
do_memory_analysis: bool = False,
):
self.num_nodes = num_nodes
self.num_devices_per_node = num_devices_per_node
@@ -79,6 +82,9 @@
num_gpu_blocks=self.num_gpu_blocks,
num_cpu_blocks=self.num_cpu_blocks,
max_num_batched_tokens=max_num_batched_tokens,
max_num_sequences=max_num_sequences,
collect_stats=collect_stats,
do_memory_analysis=do_memory_analysis,
)
# Connect the controllers.
for i in range(len(self.controllers) - 1):
@@ -180,6 +186,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
-    parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens')
+    parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
return parser

@@ -39,8 +39,10 @@ class SimpleFrontend:
self,
token_ids: List[int],
sampling_params: SamplingParams,
arrival_time: Optional[float] = None,
) -> None:
-        arrival_time = time.time()
+        if arrival_time is None:
+            arrival_time = time.time()
seqs: List[Sequence] = []
for _ in range(sampling_params.n):
seq_id = next(self.seq_counter)

@@ -28,6 +28,7 @@ class Sequence:
# Initialize the logical token blocks with the given token ids.
self.add(token_ids)
self.prompt_len = len(token_ids)
self.status = SequenceStatus.WAITING
self.output_logprobs: List[Dict[int, float]] = []
self.cumulative_logprobs = 0.0

@@ -0,0 +1,203 @@
import argparse
import os
import pickle
from typing import Any, Dict, List, Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
SYSTEMS = [
'orca-constant',
'orca-power2',
'orca-oracle',
'cacheflow',
]
SYSTEM_TO_LABEL = {
'orca-constant': 'Orca (Max)',
'orca-power2': 'Orca (Next power of 2)',
'orca-oracle': 'Orca (Oracle)',
'cacheflow': 'CacheFlow',
}
SYSTEM_TO_COLOR = {
'orca-constant': 'red',
'orca-power2': 'orange',
'orca-oracle': 'green',
'cacheflow': 'blue',
}
SYSTEM_TO_MARKER = {
'orca-constant': 'x',
'orca-power2': '^',
'orca-oracle': 's',
'cacheflow': 'o',
}
def get_results(save_dir: str) -> List[Dict[str, Any]]:
with open(os.path.join(save_dir, 'sequences.pkl'), 'rb') as f:
results = pickle.load(f)
return results
def get_request_rate(save_dir: str) -> float:
"""Get request rate from save_dir name."""
# Directory name format:
    # .../req-rate-{req_rate}/seed{seed}/duration-{duration}
save_dir = os.path.abspath(save_dir)
dir_names = save_dir.split('/')
request_rate = None
for dir_name in dir_names:
if dir_name.startswith('req-rate-'):
if request_rate is not None:
raise ValueError(f'Found multiple request rates in {save_dir}')
request_rate = float(dir_name.split('-')[-1])
if request_rate is None:
raise ValueError(f'Cannot find request rate in {save_dir}')
return request_rate
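# Example: '.../req-rate-2.5/seed0/duration-900' yields 2.5.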
def get_model(save_dir: str) -> Tuple[str, int]:
save_dir = os.path.abspath(save_dir)
dir_names = save_dir.split('/')
model = None
for dir_name in dir_names:
if '-tp' in dir_name:
if model is not None:
raise ValueError(f'Found multiple models in {save_dir}')
model = dir_name.split('-tp')[0]
tp = int(dir_name.split('-tp')[-1])
if model is None:
raise ValueError(f'Cannot find model in {save_dir}')
return model, tp
def get_system(save_dir: str) -> str:
save_dir = os.path.abspath(save_dir)
dir_names = save_dir.split('/')
for dir_name in dir_names:
if dir_name.startswith('orca-'):
return dir_name
if dir_name == 'cacheflow':
return dir_name
raise ValueError(f'Cannot find system in {save_dir}')
def get_sampling(save_dir: str) -> str:
save_dir = os.path.abspath(save_dir)
dir_names = save_dir.split('/')
for dir_name in dir_names:
if dir_name.startswith('n'):
if dir_name.endswith('-beam'):
return dir_name
if dir_name[1:].isdigit():
return dir_name
raise ValueError(f'Cannot find sampling method in {save_dir}')
def plot_normalized_latency(
exp_dir: str,
duration: int,
seed: int,
warmup: int,
xlim: Optional[float],
ylim: Optional[float],
log_scale: bool,
format: str,
) -> None:
# Get leaf directories.
save_dirs = []
for root, dirs, files in os.walk(exp_dir):
if dirs:
continue
if 'sequences.pkl' not in files:
continue
if f'seed{seed}' not in root:
continue
if f'duration-{duration}' not in root:
continue
save_dirs.append(root)
# Plot normalized latency.
perf_per_system: Dict[str, Tuple[List[float], List[float]]] = {}
for save_dir in save_dirs:
per_seq_norm_latencies = []
results = get_results(save_dir)
for seq in results:
arrival_time = seq['arrival_time']
finish_time = seq['finish_time']
output_len = seq['output_len']
if arrival_time < warmup:
continue
latency = finish_time - arrival_time
norm_latency = latency / output_len
per_seq_norm_latencies.append(norm_latency)
request_rate = get_request_rate(save_dir)
normalized_latency = np.mean(per_seq_norm_latencies)
system_name = get_system(save_dir)
if system_name not in perf_per_system:
perf_per_system[system_name] = ([], [])
perf_per_system[system_name][0].append(request_rate)
perf_per_system[system_name][1].append(normalized_latency)
print('#seqs', len(per_seq_norm_latencies))
print(f'{save_dir}: {normalized_latency:.3f} s')
# Plot normalized latency.
plt.figure(figsize=(6, 4))
for system_name in reversed(SYSTEMS):
if system_name not in perf_per_system:
continue
# Sort by request rate.
request_rates, normalized_latencies = perf_per_system[system_name]
request_rates, normalized_latencies = zip(*sorted(zip(request_rates, normalized_latencies)))
label = SYSTEM_TO_LABEL[system_name]
color = SYSTEM_TO_COLOR[system_name]
marker = SYSTEM_TO_MARKER[system_name]
plt.plot(request_rates, normalized_latencies, label=label, color=color, marker=marker)
# plt.legend()
plt.xlabel('Request rate (req/s)', fontsize=12)
plt.ylabel('Normalized latency (s/token)', fontsize=12)
if log_scale:
plt.yscale('log')
if xlim is not None:
plt.xlim(left=0, right=xlim)
if ylim is not None:
if log_scale:
plt.ylim(top=ylim)
else:
plt.ylim(bottom=0, top=ylim)
# Save figure.
model, tp = get_model(exp_dir)
sampling = get_sampling(exp_dir)
figname = f'{model}-tp{tp}-{sampling}.{format}'
os.makedirs('./figures', exist_ok=True)
plt.savefig(os.path.join('figures', figname), bbox_inches='tight')
print(f'Saved figure to ./figures/{figname}')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('exp_dir', type=str)
parser.add_argument('--duration', type=int, required=True)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--warmup', type=int, default=60)
parser.add_argument('--xlim', type=float, required=False, default=None)
parser.add_argument('--ylim', type=float, required=False, default=None)
parser.add_argument('--log', action='store_true')
parser.add_argument('--format', choices=['png', 'pdf'], default='png')
args = parser.parse_args()
plot_normalized_latency(
args.exp_dir, args.duration, args.seed, args.warmup, args.xlim, args.ylim, args.log, args.format)

plot/plot_stats.py (new file, 52 lines)

@@ -0,0 +1,52 @@
import os
import pickle
import matplotlib.pyplot as plt
STAT_NAMES = [
'input_lens',
'num_running',
'num_waiting',
'num_preemption',
'gpu_cache_usage',
'cpu_cache_usage',
'num_swapped',
'swap_in_lens',
'swap_out_lens',
]
def plot_stats(output_dir: str):
# Get stats.
with open(os.path.join(output_dir, 'stats.pkl'), 'rb') as f:
stats = pickle.load(f)
timestamps = stats['timestamps']
# Draw one figure for each stat.
num_stats = len(STAT_NAMES)
COLORS = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'orange', 'purple', 'pink', 'brown', 'gray']
fig, axs = plt.subplots(num_stats, 1, figsize=(10, 2 * num_stats))
for i, stat in enumerate(STAT_NAMES):
data = stats[stat]
if stat in ['gpu_cache_usage', 'cpu_cache_usage']:
data = [x * 100 for x in data]
stat = stat + ' (%)'
axs[i].plot(timestamps, data, color=COLORS[i % len(COLORS)])
axs[i].set_ylabel(stat.replace('_', ' '), fontdict={'fontsize': 12})
axs[i].set_ylim(bottom=0)
plt.xlabel('Time (s)')
plt.tight_layout()
fig_path = os.path.join(output_dir, 'stats.png')
plt.savefig(fig_path)
print(f'Saved stats to {fig_path}')
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('output_dir', type=str, help='Output directory.')
args = parser.parse_args()
plot_stats(args.output_dir)
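A typical invocation, assuming the experiment directory layout produced by the benchmark script above (path illustrative):

    python plot/plot_stats.py ../exp/sharegpt/opt-13b-tp1/n1/cacheflow/req-rate-1.0/seed0/duration-600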

@@ -30,6 +30,7 @@ def main(args: argparse.Namespace):
seed=args.seed,
swap_space=args.swap_space,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,