"""Benchmark the paged attention kernel."""
import argparse
import random
import time

import torch

from vllm._C import ops

NUM_BLOCKS = 1024
PARTITION_SIZE = 512
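
# NUM_BLOCKS is the size of the physical block pool that the block tables
# index into; PARTITION_SIZE is the context chunk length the v2 kernel
# reduces over (see num_partitions in main()).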


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    context_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = float(1.0 / (head_size**0.5))
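    # 1/sqrt(head_size) is the standard attention softmax scaling; it doubles
    # here as the init range so the query and KV values stay O(1).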
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device="cuda")
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device="cuda")

    context_lens = [context_len for _ in range(num_seqs)]
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

    # Create the block tables.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
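    # Each sequence gets a table of ceil(context_len / block_size) physical
    # block indices (256 with the defaults 4096/16), sampled at random from
    # the NUM_BLOCKS pool.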

    # Create the KV cache.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
    key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
    key_cache.uniform_(-scale, scale)
    value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
    value_cache = torch.empty(size=value_cache_shape,
                              dtype=dtype,
                              device="cuda")
    value_cache.uniform_(-scale, scale)
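    # The key cache splits the head dimension into chunks of x elements so
    # that each chunk spans 16 bytes; e.g. for half precision x = 16 // 2 = 8,
    # and head_size=128 gives a per-block key layout of (num_kv_heads, 16, 16, 8).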

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v2":
        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
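        # v2 splits each context into ceil(max_context_len / PARTITION_SIZE)
        # partitions (8 with the defaults 4096/512). Partial results land in
        # tmp_output, and exp_sums/max_logits hold the per-partition softmax
        # statistics needed to reduce them into the final output.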

    def run_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if version == "v1":
                ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            elif version == "v2":
                ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            else:
                raise ValueError(f"Invalid version: {version}")
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            # Close the profiler range opened before the timed loop.
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters
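
    # With --profile, cudaProfilerStart/Stop bracket a single timed iteration,
    # so an external profiler (e.g. Nsight Systems with
    # --capture-range=cudaProfilerApi) captures only the kernel under test.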

    # Warmup.
    print("Warming up...")
    run_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Benchmark the paged attention kernel.")
    parser.add_argument("--version",
                        type=str,
                        choices=["v1", "v2"],
                        default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--context-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    dtype_to_torch_dtype = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
    }
    main(
        version=args.version,
        num_seqs=args.batch_size,
        context_len=args.context_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=dtype_to_torch_dtype[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
    )
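
# Example invocation (hypothetical file name; assumes a CUDA device and the
# compiled vllm._C extension):
#   python benchmark_paged_attention.py --version v2 --batch-size 8 \
#       --context-len 4096 --dtype half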