import argparse
from itertools import accumulate
from typing import List, Optional

import nvtx
import torch

from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
                                                         get_rope)


def benchmark_rope_kernels_multi_lora(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
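    """Benchmark batched multi-LoRA RoPE against per-factor instances.

    The batched path applies rotary embeddings for several linear scaling
    factors in a single call; the non-batched path loops over one
    RotaryEmbedding instance per factor. NVTX ranges mark both phases so
    they can be compared in a profiler.
    """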
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    # simulating serving 4 LoRAs
    scaling_factors = [1, 2, 4, 8]
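    # each simulated LoRA is assumed to extend the context window by a
    # different linear RoPE scaling factor, hence one factor per adapter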
    # batched RoPE can take multiple scaling factors
    batched_rope = get_rope(head_size, rotary_dim, max_position, base,
                            is_neox_style, {
                                "type": "linear",
                                "factor": tuple(scaling_factors)
                            })
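    # with a tuple of factors, get_rope is expected to build one linear-scaling
    # RoPE whose per-factor cos/sin caches are concatenated, selected at
    # runtime via the per-token offsets computed below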
    # non-batched RoPE takes only one scaling factor, so we create multiple
    # instances to simulate the same behavior
    non_batched_ropes: List[RotaryEmbedding] = []
    for scaling_factor in scaling_factors:
        non_batched_ropes.append(
            get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
                     {
                         "type": "linear",
                         "factor": (scaling_factor, )
                     }))

    positions = torch.randint(0, max_position, (batch_size, seq_len))
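    # random activations shaped (batch_size, seq_len, num_heads * head_size),
    # i.e. heads flattened into the last dimension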
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)

    # create query offsets for batched RoPE: we concat multiple kv caches
    # together, and each query needs to find the right cache for its type
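    # the offsets are an exclusive prefix sum of the assumed per-factor cache
    # sizes (max_position * scaling_factor * 2 entries each), so entry i is
    # where the cache for scaling_factors[i] begins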
    offset_map = torch.tensor(
        list(
            accumulate([0] + [
                max_position * scaling_factor * 2
                for scaling_factor in scaling_factors[:-1]
            ])))
    query_types = torch.randint(0,
                                len(scaling_factors), (batch_size, seq_len),
                                device=device)
    # map query types to offsets
    query_offsets = offset_map[query_types]
    # the kernel takes flattened offsets
    flatten_offsets = query_offsets.flatten()

    # batch queries of the same type together for non-batched RoPE
    queries = [query[query_types == i] for i in range(len(scaling_factors))]
    keys = [key[query_types == i] for i in range(len(scaling_factors))]
    packed_qkr = zip(queries, keys, non_batched_ropes)
    # synchronize before we start timing
    torch.cuda.synchronize()
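    # NVTX ranges let the two phases be compared side by side in a timeline
    # profiler such as Nsight Systems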
    with nvtx.annotate("non-batched", color="yellow"):
        for q, k, r in packed_qkr:
            r.forward(positions, q, k)
    torch.cuda.synchronize()
    with nvtx.annotate("batched", color="green"):
        batched_rope.forward(positions, query, key, flatten_offsets)
    torch.cuda.synchronize()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Benchmark the rotary embedding kernels.")
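    # note: argparse's type=bool treats any non-empty string as True, so
    # "--is-neox-style False" still enables it; pass an empty string to disable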
    parser.add_argument("--is-neox-style", type=bool, default=True)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--num-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 192, 256],
                        default=128)
    parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
    parser.add_argument("--dtype",
                        type=str,
                        choices=["bfloat16", "float"],
                        default="float")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device",
                        type=str,
                        choices=["cuda:0", "cuda:1"],
                        default="cuda:0")
    args = parser.parse_args()
    print(args)

    benchmark_rope_kernels_multi_lora(
        is_neox_style=args.is_neox_style,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        num_heads=args.num_heads,
        head_size=args.head_size,
        rotary_dim=args.rotary_dim,
        dtype=getattr(torch, args.dtype),
        seed=args.seed,
        device=args.device,
    )
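# Example invocation under Nsight Systems (illustrative; flags can vary by
# nsys version, and the output name "rope_bench" is arbitrary):
#   nsys profile -t cuda,nvtx -o rope_bench python benchmark_rope.py \
#       --batch-size 16 --seq-len 512 --dtype bfloat16
# The "batched" and "non-batched" NVTX ranges then appear in the timeline.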