Add throughput benchmarking script (#133)

Woosuk Kwon 2023-05-28 03:20:05 -07:00 committed by GitHub
parent 337871c6fd
commit 211318d44a
12 changed files with 145 additions and 257 deletions


@@ -1,165 +0,0 @@ (file deleted)
import functools
import random
import time
from typing import List

import torch
from flash_attn.flash_attn_interface import _flash_attn_forward

from cacheflow import attention_ops


def benchmark(name, f, num_warmup=10, num_iters=100):
    for _ in range(num_warmup):
        f()
    torch.cuda.synchronize()

    start = time.time()
    for _ in range(num_iters):
        f()
    torch.cuda.synchronize()
    end = time.time()
    print(f'{name}: {(end - start) / num_iters * 1000:.3f} ms')


@torch.inference_mode()
def benchmark_multi_query_cached_kv_attention(
    query_lens: List[int],
    context_lens: List[int],
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    print(f'query_lens: {query_lens}, context_lens: {context_lens}, '
          f'num_heads: {num_heads}, head_size: {head_size}, block_size: '
          f'{block_size}, num_blocks: {num_blocks}, dtype: {dtype}')

    # Create query tensor.
    num_queries = len(query_lens)
    cu_query_lens = [0]
    for query_len in query_lens:
        cu_query_lens.append(cu_query_lens[-1] + query_len)
    num_total_tokens = cu_query_lens[-1]
    qkv = torch.randn(
        num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    query, _, _ = qkv.unbind(dim=1)

    # Create key and value cache.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_block_shape = (num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(
        size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
    value_block_shape = (num_heads, head_size, block_size)
    value_cache = torch.randn(
        size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')

    # Create block tables.
    max_context_len = max(context_lens)
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_queries):
        block_table = [
            random.randint(0, num_blocks - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')

    # Create input and output data structures.
    cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda')
    context_len_tensor = torch.tensor(context_lens, dtype=torch.int, device='cuda')
    scale = float(1.0 / (head_size ** 0.5))
    output = torch.empty(
        num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')

    # Run our implementation.
    def run_ours():
        attention_ops.multi_query_cached_kv_attention(
            cu_query_lens,
            output,
            query,
            key_cache,
            value_cache,
            scale,
            block_tables,
            context_len_tensor,
            block_size,
            max_context_len,
        )
    benchmark('Ours', run_ours)

    # Upper bound: Flash attention.
    # Because Flash attention cannot read our own cache,
    # we make key and value tensors contiguous.
    num_kv_tokens = sum(context_lens)
    cu_context_lens = [0]
    for context_len in context_lens:
        cu_context_lens.append(cu_context_lens[-1] + context_len)
    cu_context_lens = torch.tensor(cu_context_lens, dtype=torch.int, device='cuda')
    qkv = torch.randn(
        num_kv_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    _, key, value = qkv.unbind(dim=1)
    ref_output = torch.empty_like(output)

    # Run Flash attention.
    def run_flash_attn():
        _flash_attn_forward(
            query,
            key,
            value,
            ref_output,
            cu_query_lens,
            cu_context_lens,
            max(query_lens),
            max_context_len,
            dropout_p=0.0,
            softmax_scale=scale,
            causal=True,
            return_softmax=False,
        )
    benchmark('Flash attention', run_flash_attn)


if __name__ == '__main__':
    BLOCK_SIZE = 8
    NUM_BLOCKS = 1024
    DTYPE = torch.half

    # LLaMA-13B and OPT-13B
    NUM_HEADS = 40
    HEAD_SIZE = 128

    run_benchmark = functools.partial(
        benchmark_multi_query_cached_kv_attention,
        num_heads=NUM_HEADS,
        head_size=HEAD_SIZE,
        block_size=BLOCK_SIZE,
        num_blocks=NUM_BLOCKS,
        dtype=DTYPE,
    )
    run_benchmark(
        query_lens=[64] * 1,
        context_lens=[64] * 1,
    )
    run_benchmark(
        query_lens=[128] * 1,
        context_lens=[128] * 1,
    )
    run_benchmark(
        query_lens=[64] * 8,
        context_lens=[64] * 8,
    )
    run_benchmark(
        query_lens=[128] * 8,
        context_lens=[128] * 8,
    )
    run_benchmark(
        query_lens=[64, 32, 16],
        context_lens=[128, 256, 64],
    )
    run_benchmark(
        query_lens=[1024],
        context_lens=[1024],
    )
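The key-cache shape in this benchmark packs `x = 16 / element_size` consecutive elements of the head dimension into the trailing axis, so each 16-byte vector read from a block is contiguous. Below is a minimal sketch of how one token's key vector would map into that layout, written only from the shapes shown above; the helper name and indexing are illustrative, not code from the kernel.

```python
import torch

def put_key(key_cache: torch.Tensor, key: torch.Tensor, slot: int) -> None:
    # key_cache: (num_blocks, num_heads, head_size // x, block_size, x)
    # key:       (num_heads, head_size) for a single token
    # Illustration of the layout only, not the CUDA kernel's actual code.
    num_blocks, num_heads, _, block_size, x = key_cache.shape
    block, offset = divmod(slot, block_size)
    # Pack every x consecutive head_size elements into the trailing axis,
    # so one 16-byte vector per (head, group) stays contiguous in memory.
    key_cache[block, :, :, offset, :] = key.reshape(num_heads, -1, x)
```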


@@ -1,81 +0,0 @@ (file deleted)
import functools
import random
import time

import torch

from cacheflow import cache_ops


def benchmark(name, f, size: int, num_warmup=10, num_iters=100):
    for _ in range(num_warmup):
        f()
    torch.cuda.synchronize()

    start = time.time()
    for _ in range(num_iters):
        f()
    torch.cuda.synchronize()
    end = time.time()

    avg_time = (end - start) / num_iters
    print(f'[Latency] {name}: {avg_time * 1000:.3f} ms')
    print(f'[Throughput] {name}: {size / avg_time / 2 ** 30:.3f} GB/s')


@torch.inference_mode()
def test_gather_cached_kv(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    print(f'num_tokens: {num_tokens}, num_heads: {num_heads}, '
          f'head_size: {head_size}, block_size: {block_size}, '
          f'num_blocks: {num_blocks}, dtype: {dtype}')

    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')

    qkv = torch.randn(
        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    _, key, value = qkv.unbind(dim=1)

    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_cache = torch.randn(
        size=value_cache_shape, dtype=dtype, device='cuda')

    # Run the gather_cached_kv op.
    def run():
        cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping)
    benchmark('gather_cached_kv', run,
              size=num_tokens * num_heads * head_size * 2 * qkv.element_size())


if __name__ == '__main__':
    BLOCK_SIZE = 8
    NUM_BLOCKS = 1024
    DTYPE = torch.half

    # LLaMA-13B and OPT-13B
    NUM_HEADS = 40
    HEAD_SIZE = 128

    run_benchmark = functools.partial(
        test_gather_cached_kv,
        num_heads=NUM_HEADS,
        head_size=HEAD_SIZE,
        block_size=BLOCK_SIZE,
        num_blocks=NUM_BLOCKS,
        dtype=DTYPE,
    )
    for i in range(6, 12):
        run_benchmark(num_tokens=2 ** i)
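To put the reported GB/s figure in context: the `size` argument above counts the bytes gathered for keys plus values. A quick worked example at the largest sweep setting, using an assumed 1 ms latency purely to show the arithmetic:

```python
# num_tokens=2048, num_heads=40, head_size=128, fp16 (2 bytes), key + value (x2)
size = 2048 * 40 * 128 * 2 * 2      # 41,943,040 bytes = 40 MiB
avg_time = 1e-3                     # assumed 1 ms per iteration, for illustration
print(f'{size / avg_time / 2 ** 30:.3f} GB/s')  # ~39.06 GB/s with the GiB divisor above
```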

benchmarks/README.md (new file)

@@ -0,0 +1,8 @@
# Benchmarking CacheFlow

## Downloading the ShareGPT dataset

You can download the dataset by running:
```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```


@@ -0,0 +1,104 @@ (new file)
import argparse
import json
import random
import time
from typing import List, Tuple

from cacheflow import LLM, SamplingParams
from transformers import PreTrainedTokenizerBase


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[List[int], int]]:
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [
        data for data in dataset
        if len(data["conversations"]) >= 2
    ]
    # Only keep the first two turns of each conversation.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
    ]

    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        tokenized_dataset.append((prompt_token_ids[i], output_len))
    # Filter out if the prompt length + output length is greater than 2048.
    tokenized_dataset = [
        (prompt_token_ids, output_len)
        for prompt_token_ids, output_len in tokenized_dataset
        if len(prompt_token_ids) + output_len <= 2048
    ]

    # Sample the requests.
    sampled_requests = random.sample(tokenized_dataset, num_requests)
    return sampled_requests


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)

    llm = LLM(
        model=args.model,
        tensor_parallel_size=args.tensor_parallel_size,
        seed=args.seed,
    )
    tokenizer = llm.get_tokenizer()
    requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

    # Add the requests to the server.
    for prompt_token_ids, output_len in requests:
        sampling_params = SamplingParams(
            n=args.n,
            temperature=0.0 if args.use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=args.use_beam_search,
            ignore_eos=True,
            max_tokens=output_len,
        )
        # FIXME(woosuk): Do not use internal method.
        llm._add_request(
            prompt="",
            sampling_params=sampling_params,
            prompt_token_ids=prompt_token_ids,
        )

    start = time.time()
    # FIXME(woosuk): Do not use internal method.
    llm._run_server(use_tqdm=True)
    end = time.time()
    total_num_tokens = sum(
        len(prompt_token_ids) + output_len
        for prompt_token_ids, output_len in requests
    )
    print(f"Throughput: {total_num_tokens / (end - start):.2f} tokens/s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--dataset", type=str, required=True,
                        help="Path to the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n", type=int, default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts", type=int, default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()
    main(args)
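The script consumes the ShareGPT JSON downloaded via the README above. As a sanity check, here is a small sketch of the record layout `sample_requests` assumes; the field names come from the code above and the file name from the wget command, so adjust the path if you saved the dataset elsewhere.

```python
import json

# Assumes the filename from the wget command in the README.
with open("ShareGPT_V3_unfiltered_cleaned_split.json") as f:
    dataset = json.load(f)

# Each record carries a "conversations" list; the benchmark uses the first
# two turns as (prompt, completion).
first = next(d for d in dataset if len(d["conversations"]) >= 2)
prompt = first["conversations"][0]["value"]
completion = first["conversations"][1]["value"]
print(prompt[:80], "->", completion[:80])
```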


@ -1,5 +1,5 @@
from cacheflow.entrypoints.llm import LLM from cacheflow.entrypoints.llm import LLM
from cacheflow.outputs import RequestOutput from cacheflow.outputs import RequestOutput, CompletionOutput
from cacheflow.sampling_params import SamplingParams from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer from cacheflow.server.llm_server import LLMServer
@ -9,6 +9,7 @@ __all__ = [
"LLM", "LLM",
"SamplingParams", "SamplingParams",
"RequestOutput", "RequestOutput",
"CompletionOutput",
"LLMServer", "LLMServer",
"ServerArgs", "ServerArgs",
"initialize_cluster", "initialize_cluster",


@@ -87,6 +87,9 @@ class Scheduler:
     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped

+    def get_num_unfinished_seq_groups(self) -> int:
+        return len(self.waiting) + len(self.running) + len(self.swapped)
+
     def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
         # Blocks that need to be swaped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}


@@ -1,5 +1,6 @@
-from typing import List, Optional
+from typing import List, Optional, Union

+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from tqdm import tqdm

 from cacheflow.outputs import RequestOutput
@@ -31,6 +32,11 @@ class LLM:
         self.llm_server = LLMServer.from_server_args(server_args)
         self.request_counter = Counter()

+    def get_tokenizer(
+        self,
+    ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+        return self.llm_server.tokenizer
+
     def generate(
         self,
         prompts: List[str],
@@ -41,10 +47,6 @@ class LLM:
         if sampling_params is None:
             # Use default sampling params.
             sampling_params = SamplingParams()
-        # Initialize tqdm.
-        if use_tqdm:
-            pbar = tqdm(total=len(prompts), desc="Processed prompts")
-
         # Add requests to the server.
         for i in range(len(prompts)):
             prompt = prompts[i]
@@ -52,10 +54,24 @@ class LLM:
                 token_ids = None
             else:
                 token_ids = prompt_token_ids[i]
-            request_id = str(next(self.request_counter))
-            self.llm_server.add_request(request_id, prompt, sampling_params,
-                                        token_ids)
+            self._add_request(prompt, sampling_params, token_ids)
+        return self._run_server(use_tqdm)

+    def _add_request(
+        self,
+        prompt: str,
+        sampling_params: SamplingParams,
+        prompt_token_ids: Optional[List[int]],
+    ) -> None:
+        request_id = str(next(self.request_counter))
+        self.llm_server.add_request(request_id, prompt, sampling_params,
+                                    prompt_token_ids)
+
+    def _run_server(self, use_tqdm: bool) -> List[RequestOutput]:
+        # Initialize tqdm.
+        if use_tqdm:
+            num_requests = self.llm_server.get_num_unfinished_requests()
+            pbar = tqdm(total=num_requests, desc="Processed prompts")
         # Run the server.
         outputs: List[RequestOutput] = []
         while self.llm_server.has_unfinished_requests():
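Together with the `get_num_unfinished_requests` hook added to `LLMServer` below, this refactor lets a caller queue many requests and then time a single engine run, which is the pattern the new benchmark script uses. A rough usage sketch follows; it assumes the other `LLM` and `SamplingParams` arguments have defaults, and these are still internal methods, as the FIXME notes in the script point out.

```python
import time

from cacheflow import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # default model used by the benchmark

# Queue everything first; nothing runs yet.
for prompt in ["Hello, my name is", "The capital of France is"]:
    llm._add_request(
        prompt=prompt,
        sampling_params=SamplingParams(max_tokens=32),
        prompt_token_ids=None,
    )

# One timed pass over the whole queue. The tqdm total now comes from
# get_num_unfinished_requests() instead of len(prompts).
start = time.time()
outputs = llm._run_server(use_tqdm=True)
elapsed = time.time() - start
print(f"{len(outputs)} requests in {elapsed:.2f} s")
```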


@@ -151,6 +151,9 @@ class LLMServer:
         # Add the sequence group to the scheduler.
         self.scheduler.add_seq_group(seq_group)

+    def get_num_unfinished_requests(self) -> int:
+        return self.scheduler.get_num_unfinished_seq_groups()
+
     def has_unfinished_requests(self) -> bool:
         return self.scheduler.has_unfinished_seqs()


@@ -19,9 +19,8 @@ def main(args: argparse.Namespace):
             SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
     ]

-    request_id = 0
-
     # Run the server.
+    request_id = 0
     while True:
         # To test iteration-level scheduling, we add one request at each step.
         if test_prompts: