Add throughput benchmarking script (#133)
parent 337871c6fd
commit 211318d44a
@@ -1,165 +0,0 @@ (file deleted)
import functools
import random
import time
from typing import List

from flash_attn.flash_attn_interface import _flash_attn_forward
import torch

from cacheflow import attention_ops


def benchmark(name, f, num_warmup = 10, num_iters = 100):
    for _ in range(num_warmup):
        f()
    torch.cuda.synchronize()

    start = time.time()
    for _ in range(num_iters):
        f()
    torch.cuda.synchronize()
    end = time.time()
    print(f'{name}: {(end - start) / num_iters * 1000:.3f} ms')


@torch.inference_mode()
def benchmark_multi_query_cached_kv_attention(
    query_lens: List[int],
    context_lens: List[int],
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    print(f'query_lens: {query_lens}, context_lens: {context_lens}, '
          f'num_heads: {num_heads}, head_size: {head_size}, block_size: '
          f'{block_size}, num_blocks: {num_blocks}, dtype: {dtype}')
    # Create query tensor.
    num_queries = len(query_lens)
    cu_query_lens = [0]
    for query_len in query_lens:
        cu_query_lens.append(cu_query_lens[-1] + query_len)
    num_total_tokens = cu_query_lens[-1]
    qkv = torch.randn(
        num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    query, _, _ = qkv.unbind(dim=1)

    # Create key and value cache.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_block_shape = (num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(
        size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
    value_block_shape = (num_heads, head_size, block_size)
    value_cache = torch.randn(
        size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')

    # Create block tables.
    max_context_len = max(context_lens)
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_queries):
        block_table = [
            random.randint(0, num_blocks - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')

    # Create input and output data structures.
    cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda')
    context_len_tensor = torch.tensor(context_lens, dtype=torch.int, device='cuda')
    scale = float(1.0 / (head_size ** 0.5))
    output = torch.empty(
        num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')

    # Run our implementation.
    def run_ours():
        attention_ops.multi_query_cached_kv_attention(
            cu_query_lens,
            output,
            query,
            key_cache,
            value_cache,
            scale,
            block_tables,
            context_len_tensor,
            block_size,
            max_context_len,
        )
    benchmark('Ours', run_ours)

    # Upper bound: Flash attention.
    # Because Flash attention cannot read our own cache,
    # we make key and value tensors contiguous.
    num_kv_tokens = sum(context_lens)
    cu_context_lens = [0]
    for context_len in context_lens:
        cu_context_lens.append(cu_context_lens[-1] + context_len)
    cu_context_lens = torch.tensor(cu_context_lens, dtype=torch.int, device='cuda')
    qkv = torch.randn(
        num_kv_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    _, key, value = qkv.unbind(dim=1)
    ref_output = torch.empty_like(output)

    # Run Flash attention.
    def run_flash_attn():
        _flash_attn_forward(
            query,
            key,
            value,
            ref_output,
            cu_query_lens,
            cu_context_lens,
            max(query_lens),
            max_context_len,
            dropout_p=0.0,
            softmax_scale=scale,
            causal=True,
            return_softmax=False,
        )
    benchmark('Flash attention', run_flash_attn)


if __name__ == '__main__':
    BLOCK_SIZE = 8
    NUM_BLOCKS = 1024
    DTYPE = torch.half

    # LLaMA-13B and OPT-13B
    NUM_HEADS = 40
    HEAD_SIZE = 128

    run_benchmark = functools.partial(
        benchmark_multi_query_cached_kv_attention,
        num_heads=NUM_HEADS,
        head_size=HEAD_SIZE,
        block_size=BLOCK_SIZE,
        num_blocks=NUM_BLOCKS,
        dtype=DTYPE,
    )

    run_benchmark(
        query_lens=[64] * 1,
        context_lens=[64] * 1,
    )
    run_benchmark(
        query_lens=[128] * 1,
        context_lens=[128] * 1,
    )
    run_benchmark(
        query_lens=[64] * 8,
        context_lens=[64] * 8,
    )
    run_benchmark(
        query_lens=[128] * 8,
        context_lens=[128] * 8,
    )
    run_benchmark(
        query_lens=[64, 32, 16],
        context_lens=[128, 256, 64],
    )
    run_benchmark(
        query_lens=[1024],
        context_lens=[1024],
    )
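For orientation (editorial aside, not part of the diff): with the constants from the `__main__` block above (`torch.half`, `NUM_HEADS = 40`, `HEAD_SIZE = 128`, `BLOCK_SIZE = 8`), the cache shapes used by the deleted script work out as follows. This is just the script's own arithmetic spelled out:

```python
import torch

dtype = torch.half
num_heads, head_size, block_size = 40, 128, 8

# fp16 elements are 2 bytes, so x = 16 // 2 = 8; the key cache groups x
# elements of the head dimension into the last axis, presumably so that each
# group spans 16 contiguous bytes.
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_block_shape = (num_heads, head_size // x, block_size, x)   # (40, 16, 8, 8)
value_block_shape = (num_heads, head_size, block_size)         # (40, 128, 8)

# Each block holds num_heads * head_size * block_size = 40,960 half-precision
# elements, i.e. 80 KiB for keys plus 80 KiB for values per block.
print(key_block_shape, value_block_shape)
```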
@@ -1,81 +0,0 @@ (file deleted)
import functools
import random
import time

import torch

from cacheflow import cache_ops


def benchmark(name, f, size: int, num_warmup = 10, num_iters = 100):
    for _ in range(num_warmup):
        f()
    torch.cuda.synchronize()

    start = time.time()
    for _ in range(num_iters):
        f()
    torch.cuda.synchronize()
    end = time.time()
    avg_time = (end - start) / num_iters
    print(f'[Latency] {name}: {avg_time * 1000:.3f} ms')
    print(f'[Throughput] {name}: {size / avg_time / 2 ** 30:.3f} GB/s')


@torch.inference_mode()
def test_gather_cached_kv(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    print(f'num_tokens: {num_tokens}, num_heads: {num_heads}, '
          f'head_size: {head_size}, block_size: {block_size}, '
          f'num_blocks: {num_blocks}, dtype: {dtype}')

    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')

    qkv = torch.randn(
        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    _, key, value = qkv.unbind(dim=1)

    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_cache = torch.randn(
        size=value_cache_shape, dtype=dtype, device='cuda')

    # Run the gather_cached_kv op.
    def run():
        cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping)

    benchmark('gather_cached_kv', run,
              size=num_tokens * num_heads * head_size * 2 * qkv.element_size())


if __name__ == '__main__':
    BLOCK_SIZE = 8
    NUM_BLOCKS = 1024
    DTYPE = torch.half

    # LLaMA-13B and OPT-13B
    NUM_HEADS = 40
    HEAD_SIZE = 128

    run_benchmark = functools.partial(
        test_gather_cached_kv,
        num_heads=NUM_HEADS,
        head_size=HEAD_SIZE,
        block_size=BLOCK_SIZE,
        num_blocks=NUM_BLOCKS,
        dtype=DTYPE,
    )

    for i in range(6, 12):
        run_benchmark(num_tokens=2 ** i)
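A quick sanity check on the numbers this deleted script reports (again an aside, not part of the diff): the `size` passed to `benchmark` counts the key and value bytes on one side of the gather, so the printed GB/s figure reflects that one direction of traffic. For the largest run in the loop above:

```python
# num_tokens = 2 ** 11 = 2048, with 40 heads, head size 128, fp16 (2 bytes).
num_tokens, num_heads, head_size, elem_size = 2048, 40, 128, 2
size = num_tokens * num_heads * head_size * 2 * elem_size  # key + value bytes
print(size / 2 ** 20)  # 40.0 MiB gathered per call (one side of the copy)
```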
benchmarks/README.md (new file, +8)
@@ -0,0 +1,8 @@
# Benchmarking CacheFlow

## Downloading the ShareGPT dataset

You can download the dataset by running:
```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```
benchmarks/benchmark_throughput.py (new file, +104)
@@ -0,0 +1,104 @@
import argparse
import json
import random
import time
from typing import List, Tuple

from cacheflow import LLM, SamplingParams
from transformers import PreTrainedTokenizerBase


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[List[int], int]]:
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with fewer than 2 turns.
    dataset = [
        data for data in dataset
        if len(data["conversations"]) >= 2
    ]
    # Only keep the first two turns of each conversation.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
    ]

    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        tokenized_dataset.append((prompt_token_ids[i], output_len))
    # Filter out requests whose prompt length + output length is greater than 2048.
    tokenized_dataset = [
        (prompt_token_ids, output_len)
        for prompt_token_ids, output_len in tokenized_dataset
        if len(prompt_token_ids) + output_len <= 2048
    ]

    # Sample the requests.
    sampled_requests = random.sample(tokenized_dataset, num_requests)
    return sampled_requests


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)

    llm = LLM(
        model=args.model,
        tensor_parallel_size=args.tensor_parallel_size,
        seed=args.seed,
    )
    tokenizer = llm.get_tokenizer()
    requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

    # Add the requests to the server.
    for prompt_token_ids, output_len in requests:
        sampling_params = SamplingParams(
            n=args.n,
            temperature=0.0 if args.use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=args.use_beam_search,
            ignore_eos=True,
            max_tokens=output_len,
        )
        # FIXME(woosuk): Do not use internal method.
        llm._add_request(
            prompt="",
            sampling_params=sampling_params,
            prompt_token_ids=prompt_token_ids,
        )

    start = time.time()
    # FIXME(woosuk): Do not use internal method.
    llm._run_server(use_tqdm=True)
    end = time.time()
    total_num_tokens = sum(
        len(prompt_token_ids) + output_len
        for prompt_token_ids, output_len in requests
    )
    print(f"Throughput: {total_num_tokens / (end - start):.2f} tokens/s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--dataset", type=str, required=True,
                        help="Path to the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n", type=int, default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts", type=int, default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()
    main(args)
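For reference, a typical invocation of the script above (a sketch that simply spells out the defaults from the argparse block, using the ShareGPT JSON from the README):

```bash
# Assumes the ShareGPT JSON has been downloaded to the current directory.
python benchmarks/benchmark_throughput.py \
    --dataset ShareGPT_V3_unfiltered_cleaned_split.json \
    --model facebook/opt-125m \
    --tensor-parallel-size 1 \
    --num-prompts 1000 \
    --seed 0
```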
@@ -1,5 +1,5 @@
from cacheflow.entrypoints.llm import LLM
-from cacheflow.outputs import RequestOutput
+from cacheflow.outputs import RequestOutput, CompletionOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer

@@ -9,6 +9,7 @@ __all__ = [
    "LLM",
    "SamplingParams",
    "RequestOutput",
    "CompletionOutput",
    "LLMServer",
    "ServerArgs",
    "initialize_cluster",
@@ -87,6 +87,9 @@ class Scheduler:
    def has_unfinished_seqs(self) -> bool:
        return self.waiting or self.running or self.swapped

    def get_num_unfinished_seq_groups(self) -> int:
        return len(self.waiting) + len(self.running) + len(self.swapped)

    def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
@@ -1,5 +1,6 @@
-from typing import List, Optional
+from typing import List, Optional, Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from tqdm import tqdm

from cacheflow.outputs import RequestOutput

@@ -31,6 +32,11 @@ class LLM:
        self.llm_server = LLMServer.from_server_args(server_args)
        self.request_counter = Counter()

    def get_tokenizer(
        self,
    ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
        return self.llm_server.tokenizer

    def generate(
        self,
        prompts: List[str],

@@ -41,10 +47,6 @@ class LLM:
        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()
        # Initialize tqdm.
        if use_tqdm:
            pbar = tqdm(total=len(prompts), desc="Processed prompts")

        # Add requests to the server.
        for i in range(len(prompts)):
            prompt = prompts[i]

@@ -52,10 +54,24 @@ class LLM:
                token_ids = None
            else:
                token_ids = prompt_token_ids[i]
            self._add_request(prompt, sampling_params, token_ids)
        return self._run_server(use_tqdm)

    def _add_request(
        self,
        prompt: str,
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]],
    ) -> None:
        request_id = str(next(self.request_counter))
        self.llm_server.add_request(request_id, prompt, sampling_params,
-                                    token_ids)
+                                    prompt_token_ids)

    def _run_server(self, use_tqdm: bool) -> List[RequestOutput]:
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_server.get_num_unfinished_requests()
            pbar = tqdm(total=num_requests, desc="Processed prompts")
        # Run the server.
        outputs: List[RequestOutput] = []
        while self.llm_server.has_unfinished_requests():
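To show how these entrypoint changes fit together, here is a minimal sketch (not an excerpt from the repo; the model name is just the script's default and `max_tokens=16` is arbitrary). The benchmark script relies on the new `get_tokenizer()` accessor plus the internal `_add_request()`/`_run_server()` pair, while ordinary callers keep using `generate()`:

```python
from cacheflow import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")

# New public accessor added in this commit.
tokenizer = llm.get_tokenizer()
prompt_token_ids = tokenizer(["Hello, my name is"]).input_ids

# Internal path used by benchmark_throughput.py: enqueue pre-tokenized
# requests, then drain the server once.
for token_ids in prompt_token_ids:
    llm._add_request(prompt="",
                     sampling_params=SamplingParams(max_tokens=16),
                     prompt_token_ids=token_ids)
outputs = llm._run_server(use_tqdm=True)
```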
@@ -151,6 +151,9 @@ class LLMServer:
        # Add the sequence group to the scheduler.
        self.scheduler.add_seq_group(seq_group)

    def get_num_unfinished_requests(self) -> int:
        return self.scheduler.get_num_unfinished_seq_groups()

    def has_unfinished_requests(self) -> bool:
        return self.scheduler.has_unfinished_seqs()
@@ -19,9 +19,8 @@ def main(args: argparse.Namespace):
         SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
    ]

-   request_id = 0

    # Run the server.
+   request_id = 0
    while True:
        # To test iteration-level scheduling, we add one request at each step.
        if test_prompts: