2025-02-02 14:58:18 -05:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
2025-01-16 21:21:40 +05:30
|
|
|
import argparse
|
|
|
|
import copy
|
|
|
|
import json
|
|
|
|
import pickle
|
|
|
|
import time
|
|
|
|
from dataclasses import dataclass
|
|
|
|
from enum import Enum, auto
|
|
|
|
from itertools import product
|
|
|
|
from pathlib import Path
|
|
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.utils.benchmark as TBenchmark
|
|
|
|
from torch.utils.benchmark import Measurement as TMeasurement
|
|
|
|
from utils import ArgPool, Bench, CudaGraphBenchParams
|
|
|
|
from weight_shapes import WEIGHT_SHAPES
|
|
|
|
|
|
|
|
from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand
|
|
|
|
from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice
|
|
|
|
from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink
|
|
|
|
from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
|
|
|
|
from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink
|
|
|
|
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
|
|
|
from vllm.utils import FlexibleArgumentParser
|
|
|
|
|
|
|
|
# Default values for each benchmark sweep dimension.
DEFAULT_MODELS = [*WEIGHT_SHAPES]
DEFAULT_TP_SIZES = [1]
DEFAULT_BATCH_SIZES = [
    1, 16, 32, 64, 128, 192, 256, 320, 384, 448, 512,
    640, 768, 896, 1024,
    2048, 3072, 4096, 5120, 6144, 7168, 8192,
]
DEFAULT_HIDDEN_SIZES = [1024, 2048, 4096, 8192, 16384]
DEFAULT_LORA_RANKS = [16]
DEFAULT_NUM_LORAS = [1, 2, 3, 4]
DEFAULT_SORT_BY_LORA_IDS = [False, True]
# A sequence length of 1 selects the decode (bgmv) ops in run(); longer
# sequences select the prefill (sgmv) ops.
DEFAULT_SEQ_LENGTHS = [1]
DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False]
|
|
|
|
|
|
|
|
|
|
|
|
# Utilities
|
|
|
|
def dtype_to_str(dtype: torch.dtype):
    """Return the short tag ("f16", "bf16" or "f32") used in benchmark labels.

    Raises:
        ValueError: if ``dtype`` is not float16, bfloat16 or float32.
    """
    known_tags = (
        (torch.float16, "f16"),
        (torch.bfloat16, "bf16"),
        (torch.float32, "f32"),
    )
    for candidate, tag in known_tags:
        if dtype == candidate:
            return tag
    raise ValueError(f"Unsupported dtype {dtype}")
|
|
|
|
|
|
|
|
|
|
|
|
def make_rand_lora_weight_tensor(k: int,
                                 n: int,
                                 num_loras: int,
                                 dtype: torch.dtype,
                                 device: str = "cuda") -> torch.Tensor:
    """Create a random stack of ``num_loras`` LoRA weight matrices.

    Values are sampled on the CPU with ``torch.rand`` and then moved to
    ``device``. Each matrix is laid out as (n, k), i.e. column major with
    respect to a k x n matmul operand.
    """
    weight_shape = (num_loras, n, k)
    host_weights = torch.rand(weight_shape, dtype=dtype)
    return host_weights.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
def make_rand_tensors(
    a_shape: Tuple[int],
    b_shape: Tuple[int],
    c_shape: Tuple[int],
    a_dtype: torch.dtype,
    b_dtype: torch.dtype,
    c_dtype: torch.dtype,
    num_slices: int,
    device: str = "cuda",
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
    """
    Make LoRA input/output matrices.

    Returns ``(A, Bs, C)``:
      * ``A``  - random tensor of ``a_shape`` / ``a_dtype``;
      * ``Bs`` - list of ``num_slices`` independent random tensors of
        ``b_shape`` / ``b_dtype`` (LoRA weights, column major);
      * ``C``  - zero-filled tensor of ``c_shape`` / ``c_dtype``.
    All tensors are created on the CPU and then moved to ``device``.
    """

    def rand_on_device(shape, dtype):
        return torch.rand(shape, dtype=dtype).to(device)

    inputs = rand_on_device(a_shape, a_dtype)

    # LoRA weights column major; one tensor per slice.
    weights = []
    for _ in range(num_slices):
        weights.append(rand_on_device(b_shape, b_dtype))

    outputs = torch.zeros(c_shape, dtype=c_dtype).to(device)
    return inputs, weights, outputs
|
|
|
|
|
|
|
|
|
|
|
|
def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int,
                             sort_by_lora_id: bool,
                             device: str) -> torch.Tensor:
    """
    Map every prompt to a LoRA ID in the range [0, num_active_loras),
    where 0 refers to the first LoRA, 1 to the second and so on.

    Args:
        num_prompts: Number of prompts (sequences) to map.
        num_active_loras: Number of distinct LoRA IDs to draw from (> 0).
        sort_by_lora_id: If False, IDs are drawn uniformly at random. If
            True, prompts are split into contiguous runs of
            ``max(num_prompts // num_active_loras, 1)`` prompts with
            ascending IDs; any leftover prompts get the last ID.
        device: Device of the returned tensor.

    Returns:
        A 1-D ``torch.long`` tensor of length ``num_prompts`` on ``device``.
    """
    assert num_active_loras > 0

    if not sort_by_lora_id:
        # Sample on the CPU generator (as before) and then honour `device`;
        # previously the tensor was always returned on the CPU.
        return torch.randint(0,
                             num_active_loras, (num_prompts, ),
                             dtype=torch.long).to(device=device)

    # Divide LoRAs equally and in order. Prompt i lands in run
    # i // part_size; runs beyond the last LoRA are folded into it, so the
    # remainder goes to the last LoRA.
    part_size = max(num_prompts // num_active_loras, 1)
    run_ids = torch.div(torch.arange(num_prompts, dtype=torch.long),
                        part_size,
                        rounding_mode="floor")
    return torch.clamp(run_ids, max=num_active_loras - 1).to(device=device)
|
|
|
|
|
|
|
|
|
|
|
|
def make_token_lora_mapping(num_tokens: int, num_prompts: int,
                            prompt_lora_mapping: torch.Tensor,
                            seq_len_tensor: torch.Tensor, device: str):
    """
    Make token_lora_mapping from prompt_lora_mapping and seq_len_tensor.

    Prompts are laid out back to back; prompt ``b`` spans
    ``seq_len_tensor[b]`` tokens, each of which is mapped to
    ``prompt_lora_mapping[b]``.

    Args:
        num_tokens: Total number of tokens; must equal the sum of the first
            ``num_prompts`` sequence lengths.
        num_prompts: Number of prompts described by ``prompt_lora_mapping``.
        prompt_lora_mapping: 1-D tensor of per-prompt LoRA IDs.
        seq_len_tensor: 1-D tensor of per-prompt sequence lengths.
        device: Device of the returned tensor.

    Returns:
        A 1-D ``torch.long`` tensor of length ``num_tokens`` on ``device``.

    Raises:
        ValueError: if the sequence lengths do not sum to ``num_tokens``.
    """
    assert prompt_lora_mapping.shape[0] == num_prompts

    seq_lens = seq_len_tensor[:num_prompts].to(device="cpu", dtype=torch.long)
    total = int(seq_lens.sum())
    # The previous list-based version silently zero-padded or over-extended
    # the mapping on a mismatch; fail loudly instead.
    if total != num_tokens:
        raise ValueError(f"Sequence lengths sum to {total} tokens, "
                         f"expected {num_tokens}")

    lora_ids = prompt_lora_mapping.to(device="cpu", dtype=torch.long)
    token_lora_mapping = torch.repeat_interleave(lora_ids, seq_lens)
    return token_lora_mapping.to(device=device)
|
|
|
|
|
|
|
|
|
|
|
|
def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor,
                   lora_weights: torch.Tensor,
                   seq_lens_cpu: torch.Tensor,
                   prompt_lora_mapping_cpu: torch.Tensor, scaling: float,
                   add_inputs: Optional[bool]):
    """
    Torch group gemm reference implementation to test correctness of
    benchmarking operations.

    Sequences are laid out back to back along dim 0 of ``input``. Sequence
    ``b`` (``seq_lens_cpu[b]`` rows) is multiplied by
    ``lora_weights[prompt_lora_mapping_cpu[b]]`` via ``F.linear`` (x @ w.T)
    and scaled by ``scaling``. The stacked per-sequence results are then
    added to ``ref_out`` in place when ``add_inputs`` is truthy, otherwise
    copied into it.

    Args:
        ref_out: Output tensor, updated in place; must be compatible with
            the stacked result (``sum(seq_lens_cpu)`` rows).
        input: Input rows for all sequences, concatenated.
        lora_weights: Indexable by LoRA ID; each entry is an (n, k) weight.
        seq_lens_cpu: Per-sequence row counts.
        prompt_lora_mapping_cpu: Per-sequence LoRA IDs.
        scaling: Multiplier applied to every product.
        add_inputs: Accumulate into ``ref_out`` instead of overwriting it.
    """
    out_list = []
    current_offset = 0
    for seq_idx, seq_len in enumerate(seq_lens_cpu.tolist()):
        x = input[current_offset:current_offset + seq_len, :]
        current_offset += seq_len
        w = lora_weights[prompt_lora_mapping_cpu[seq_idx]]
        out_list.append(torch.nn.functional.linear(x, w) * scaling)

    # No sequences: nothing to write (torch.cat would fail on an empty list).
    if not out_list:
        return

    cat_result = torch.cat(out_list, dim=0)
    if add_inputs:
        ref_out += cat_result
    else:
        ref_out.copy_(cat_result)
|
|
|
|
|
|
|
|
|
|
|
|
class OpType(Enum):
    """
    LoRA Ops to benchmark and its properties.
    """
    SGMV_SHRINK = auto()
    BGMV_SHRINK = auto()
    SGMV_EXPAND = auto()
    BGMV_EXPAND = auto()
    BGMV_EXPAND_SLICE = auto()

    @staticmethod
    def from_str(s: str) -> "OpType":
        """Parse a case-insensitive op name such as ``"sgmv_shrink"``.

        Raises:
            ValueError: if ``s`` does not name a known op.
        """
        if s.lower() == 'sgmv_shrink':
            return OpType.SGMV_SHRINK
        if s.lower() == 'sgmv_expand':
            return OpType.SGMV_EXPAND
        if s.lower() == 'bgmv_shrink':
            return OpType.BGMV_SHRINK
        if s.lower() == 'bgmv_expand':
            return OpType.BGMV_EXPAND
        if s.lower() == "bgmv_expand_slice":
            return OpType.BGMV_EXPAND_SLICE
        raise ValueError(f"Unrecognized str {s} to convert to OpType")

    def is_shrink_fn(self) -> bool:
        """True for the shrink ops (hidden_size -> lora_rank)."""
        return self in [OpType.SGMV_SHRINK, OpType.BGMV_SHRINK]

    def is_expand_fn(self) -> bool:
        """True for the non-slice expand ops (lora_rank -> hidden_size)."""
        return self in [OpType.SGMV_EXPAND, OpType.BGMV_EXPAND]

    def is_prefill_op(self) -> bool:
        """True for the sgmv ops, benchmarked when seq_length > 1."""
        return self in [OpType.SGMV_SHRINK, OpType.SGMV_EXPAND]

    def is_decode_op(self) -> bool:
        """True for the bgmv ops, benchmarked when seq_length == 1."""
        return self in [
            OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE
        ]

    def is_expand_slice_fn(self) -> bool:
        """True for the sliced expand op."""
        return self in [OpType.BGMV_EXPAND_SLICE]

    def num_slices(self) -> List[int]:
        """Return the num_slices values to benchmark for this op."""
        if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]:
            # SGMV kernels supports slices
            return [1, 2, 3]
        if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]:
            return [1]
        if self in [OpType.BGMV_EXPAND_SLICE]:
            return [2, 3]
        raise ValueError(f"Unrecognized OpType {self}")

    def mkn(self, batch_size: int, seq_length: int, hidden_size: int,
            lora_rank: int) -> Tuple[int, int, int]:
        """Return (m, k, n) of the per-slice matmul A(m x k) x B(k x n).

        m is always ``batch_size * seq_length``. Shrink ops map
        hidden_size -> lora_rank; expand ops map lora_rank -> hidden_size.
        """
        num_tokens = batch_size * seq_length
        if self.is_shrink_fn():
            m = num_tokens
            k = hidden_size
            n = lora_rank
        else:
            assert self.is_expand_fn() or self.is_expand_slice_fn()
            m = num_tokens
            k = lora_rank
            n = hidden_size
        return m, k, n

    def matmul_dtypes(
            self, op_dtype: torch.dtype
    ) -> Tuple[torch.dtype, torch.dtype, torch.dtype]:
        """
        return a type, b type and c type for A x B = C

        Shrink ops produce a float32 output; expand ops consume a float32
        input.
        """
        if self.is_shrink_fn():
            return op_dtype, op_dtype, torch.float32
        else:
            assert self.is_expand_fn() or self.is_expand_slice_fn()
            return torch.float32, op_dtype, op_dtype

    def matmul_shapes(
            self, batch_size: int, seq_length: int, hidden_size: int,
            lora_rank: int, num_loras: int,
            num_slices: int) -> Tuple[Tuple[int], Tuple[int], Tuple[int]]:
        """
        Given num_slices, return the shapes of the A, B, and C matrices
        in A x B = C, for the op_type
        """
        m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)

        b_shape = (num_loras, n, k)  # col-major
        if self == OpType.SGMV_SHRINK:
            # SGMV shrink supports num_slices inherently in the kernel
            return ((m, k), b_shape, (num_slices, m, n))
        if self == OpType.SGMV_EXPAND:
            # SGMV expand supports num_slices inherently in the kernel
            return ((num_slices, m, k), b_shape, (m, n * num_slices))
        if self == OpType.BGMV_SHRINK:
            return ((m, k), b_shape, (m, n))
        if self == OpType.BGMV_EXPAND:
            return ((m, k), b_shape, (m, n))
        if self == OpType.BGMV_EXPAND_SLICE:
            return ((num_slices, m, k), b_shape, (m, n * num_slices))

        raise ValueError(f"Unrecognized op_type {self}")

    def bench_fn(self) -> Callable:
        """Return the kernel entry point benchmarked for this op."""

        def emulate_bgmv_expand_slice(kwargs_list: List[Dict[str, Any]]):
            # bgmv_expand_slice handles one slice per call; run every slice.
            for x in kwargs_list:
                bgmv_expand_slice(**x)

        if self == OpType.SGMV_SHRINK:
            return sgmv_shrink
        if self == OpType.SGMV_EXPAND:
            return sgmv_expand
        if self == OpType.BGMV_SHRINK:
            return bgmv_shrink
        if self == OpType.BGMV_EXPAND:
            return bgmv_expand
        if self == OpType.BGMV_EXPAND_SLICE:
            return emulate_bgmv_expand_slice
        raise ValueError(f"Unrecognized optype {self}")

    def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
                           lora_weights: List[torch.Tensor],
                           **kwargs) -> None:
        """Each benchmark operation expected the input, lora_weights and outputs
        in a slightly different format. Refer to self.matmul_shapes().
        run_ref_group_gemm accounts for those differences in executing a
        reference group gemm for correctness testing.

        ``output`` is written in place; ``kwargs`` are forwarded to
        ``ref_group_gemm``. Every supported op returns after its branch;
        previously all branches fell through to the final ``raise``.

        Raises:
            ValueError: for an unsupported op.
        """
        w_dtype = lora_weights[0].dtype
        num_slices = len(lora_weights)

        if self == OpType.SGMV_SHRINK:
            # Output is (num_slices, m, n); slice i gets input x weights[i].
            for slice_idx in range(num_slices):
                ref_group_gemm(ref_out=output[slice_idx, :],
                               input=input,
                               lora_weights=lora_weights[slice_idx],
                               **kwargs)
            return

        if self == OpType.BGMV_SHRINK:
            assert num_slices == 1
            ref_group_gemm(ref_out=output,
                           input=input,
                           lora_weights=lora_weights[0],
                           **kwargs)
            return

        if self == OpType.BGMV_EXPAND:
            assert num_slices == 1
            ref_group_gemm(ref_out=output,
                           input=input.clone().to(dtype=w_dtype),
                           lora_weights=lora_weights[0],
                           **kwargs)
            return

        if self in (OpType.SGMV_EXPAND, OpType.BGMV_EXPAND_SLICE):
            # Input is (num_slices, m, k); slice i is written into columns
            # [i * hidden_size, (i + 1) * hidden_size) of the output.
            hidden_size = lora_weights[0].shape[1]
            for slice_idx in range(num_slices):
                slice_offset = slice_idx * hidden_size
                ref_group_gemm(
                    ref_out=output[:, slice_offset:slice_offset + hidden_size],
                    input=input[slice_idx].clone().to(dtype=w_dtype),
                    lora_weights=lora_weights[slice_idx],
                    **kwargs)
            return

        raise ValueError(f"Unrecognized optype {self}")
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class BenchmarkContext:
    """
    LoRA benchmark context: one point of the benchmark parameter sweep.

    ``seq_length`` and ``num_slices`` start unset and are filled in per op
    through ``with_seq_length`` / ``with_num_slices``, which return shallow
    copies rather than mutating this instance.
    """
    batch_size: int
    hidden_size: int
    num_loras: int
    num_active_loras: int
    lora_rank: int
    sort_by_lora_id: bool
    dtype: torch.dtype
    seq_length: Optional[int] = None
    num_slices: Optional[int] = None  # num_slices for slice based ops

    def with_seq_length(self, seq_length: int) -> "BenchmarkContext":
        """Return a shallow copy of this context with ``seq_length`` set."""
        clone = copy.copy(self)
        clone.seq_length = seq_length
        return clone

    def with_num_slices(self, num_slices: int) -> "BenchmarkContext":
        """Return a shallow copy of this context with ``num_slices`` set."""
        clone = copy.copy(self)
        clone.num_slices = num_slices
        return clone

    def bench_label(self) -> str:
        """Label grouping benchmark results by dtype."""
        return f"lora-{self.dtype}"

    def bench_sublabel(self, op_type: OpType) -> str:
        """JSON description of the problem size benchmarked for ``op_type``."""
        m, k, n = op_type.mkn(self.batch_size, self.seq_length,
                              self.hidden_size, self.lora_rank)
        # Key order is significant: it is the order json.dumps emits.
        fields = [
            ('bs', self.batch_size),
            ('sl', self.seq_length),
            ('m', m),
            ('k', k),
            ('n', n),
            ('num_loras', self.num_loras),
            ('sort_by_lora', self.sort_by_lora_id),
            ('num_slices', self.num_slices),
        ]
        return json.dumps(dict(fields))
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class BenchmarkTensors:
    """
    Input/Output tensors used for benchmarks
    """
    # matmul tensors
    input: torch.Tensor
    lora_weights_lst: List[torch.Tensor]
    output: torch.Tensor
    # metadata tensors
    seq_lens: torch.Tensor
    seq_start_loc: torch.Tensor
    prompt_lora_mapping: torch.Tensor
    token_lora_mapping: torch.Tensor

    def io_types(self) -> str:
        """Describe the matmul dtypes as ``"<input>x<weight>=><output>"``."""
        return (f"{dtype_to_str(self.input.dtype)}x"
                f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>"
                f"{dtype_to_str(self.output.dtype)}")

    @staticmethod
    def make(ctx: BenchmarkContext,
             op_type: OpType,
             device: str = "cuda") -> "BenchmarkTensors":
        """Build random benchmark tensors for ``op_type`` under ``ctx``.

        The matmul tensors are allocated on ``device``. Previously ``device``
        was ignored and they always went to the ``make_rand_tensors``
        default. The metadata tensors stay on the CPU.
        """
        # Make input / output matmul tensors.
        a_shape, b_shape, c_shape = op_type.matmul_shapes(
            ctx.batch_size, ctx.seq_length, ctx.hidden_size, ctx.lora_rank,
            ctx.num_loras, ctx.num_slices)
        a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype)
        input_tensor, lora_weights, output_tensor = \
            make_rand_tensors(a_shape, b_shape, c_shape, a_type, b_type, c_type,
                              num_slices=ctx.num_slices, device=device)

        # Make metadata tensors.
        # Keep the metadata tensors in the CPU for further processing if needed.
        # The tensors get moved to the GPU before benchmarking.
        assert ctx.num_active_loras <= ctx.num_loras
        total_tokens = ctx.batch_size * ctx.seq_length

        # Prepare seq lens tensor (every sequence has length ctx.seq_length)
        seq_len_tensor = torch.randint(ctx.seq_length, ctx.seq_length + 1,
                                       (ctx.batch_size, ))
        # Prepare seq_start_loc tensor (exclusive prefix sum of seq lens)
        seq_start_loc_tensor = torch.cumsum(torch.tensor(
            [0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
                                            dim=0)
        assert total_tokens == seq_len_tensor.sum()
        # Prepare prompt lora indices tensor
        prompt_lora_indices_tensor = make_prompt_lora_mapping(
            ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu")
        # Prepare token lora indices tensor
        token_lora_indices_tensor = make_token_lora_mapping(
            total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
            seq_len_tensor, "cpu")

        return BenchmarkTensors(input_tensor, lora_weights, output_tensor,
                                seq_len_tensor, seq_start_loc_tensor,
                                prompt_lora_indices_tensor,
                                token_lora_indices_tensor)

    def sanity_check(self) -> None:
        """
        Fails asserts when non-conformality is detected.
        """
        # Token dimension is the second-to-last axis for both 2-D and 3-D
        # inputs.
        num_tokens = self.input.shape[-2]
        # check metadata tensors
        assert torch.sum(self.seq_lens) == num_tokens
        num_seqs = self.seq_lens.shape[0]
        assert self.seq_start_loc.shape[0] == num_seqs
        assert self.prompt_lora_mapping.shape[0] == num_seqs
        assert self.token_lora_mapping.shape[0] == num_tokens

    def to_device(self, device: str):
        """
        Transfer tensors to device if the tensors aren't already on the device
        """

        def _move(tensor: torch.Tensor) -> torch.Tensor:
            if tensor.device != device:
                tensor = tensor.to(device=device)
            return tensor

        self.input = _move(self.input)
        self.output = _move(self.output)
        self.seq_lens = _move(self.seq_lens)
        self.seq_start_loc = _move(self.seq_start_loc)
        self.prompt_lora_mapping = _move(self.prompt_lora_mapping)
        self.token_lora_mapping = _move(self.token_lora_mapping)
        # Update the weight list in place so existing references see it.
        for i, weight in enumerate(self.lora_weights_lst):
            self.lora_weights_lst[i] = _move(weight)

    def metadata(self) -> Tuple[int, int, int, int]:
        """
        Return num_seqs, num_tokens, max_seq_len and num_slices
        """
        num_seqs = self.seq_lens.shape[0]
        num_tokens = self.token_lora_mapping.shape[0]
        max_seq_len = torch.max(self.seq_lens).item()
        num_slices = len(self.lora_weights_lst)
        return num_seqs, num_tokens, max_seq_len, num_slices

    def convert_to_sgmv_benchmark_tensors(self):
        """
        For sgmv punica kernels, when consecutive sequences have the
        same LoRA ID, we just merge them together.
        This happens in punica.py::compute_metadata
        """

        # Collapse seq_lens and seq_start_loc
        _, seq_lens = torch.unique_consecutive(self.token_lora_mapping,
                                               return_counts=True)
        cum_result = torch.cumsum(seq_lens, dim=0)
        seq_start_loc = torch.zeros_like(seq_lens)
        seq_start_loc[1:].copy_(cum_result[:-1])

        # Collapse prompt mapping
        prompt_lora_mapping = torch.unique_consecutive(
            self.prompt_lora_mapping)

        assert torch.sum(seq_lens) == torch.sum(self.seq_lens), \
         f"dont match - new {torch.sum(seq_lens)} vs {torch.sum(self.seq_lens)}"

        self.prompt_lora_mapping = prompt_lora_mapping.to(
            dtype=self.prompt_lora_mapping.dtype)
        self.seq_lens = seq_lens.to(dtype=self.seq_lens.dtype)
        self.seq_start_loc = seq_start_loc.to(dtype=self.seq_start_loc.dtype)

    def as_sgmv_shrink_kwargs(self) -> Dict[str, Any]:
        """Build sgmv_shrink keyword arguments (collapses sequences first)."""
        self.convert_to_sgmv_benchmark_tensors()
        self.sanity_check()
        self.to_device(self.input.device)

        num_seqs, num_tokens, max_seq_len, num_slices = self.metadata()

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
            0].shape, self.output.shape
        # Expected input shape [num_tokens, hidden_size]
        assert len(i_shape) == 2
        assert i_shape[0] == num_tokens
        hidden_size = i_shape[1]
        # Expected lora weight shape [num_loras, lora_rank, hidden_size]
        assert len(lw_shape) == 3
        assert lw_shape[2] == hidden_size
        lora_rank = lw_shape[1]
        # Expected output shape [num_slices, num_tokens, lora_rank]
        assert len(o_shape) == 3
        assert o_shape == (num_slices, num_tokens, lora_rank)

        return {
            'inputs': self.input,
            'lora_a_weights': self.lora_weights_lst,
            'output_tensor': self.output,
            'b_seq_start_loc': self.seq_start_loc,
            'seq_len_tensor': self.seq_lens,
            'lora_indices_tensor': self.prompt_lora_mapping,
            'batches': num_seqs,
            'max_seq_length': max_seq_len,
            'token_nums': num_tokens,
            'scaling': 1.0,
        }

    def as_sgmv_expand_kwargs(self, add_inputs: bool) -> Dict[str, Any]:
        """Build sgmv_expand keyword arguments (collapses sequences first)."""
        self.convert_to_sgmv_benchmark_tensors()
        self.sanity_check()
        self.to_device(self.input.device)

        num_seqs, num_tokens, max_seq_len, num_slices = self.metadata()

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
            0].shape, self.output.shape
        # Expected input shape : [num_slices, num_tokens, lora_rank]
        assert len(i_shape) == 3
        assert i_shape[0] == num_slices
        assert i_shape[1] == num_tokens
        lora_rank = i_shape[2]
        # Expected lora weight shape : [num_lora, hidden_size, lora_rank]
        assert len(lw_shape) == 3
        assert lw_shape[2] == lora_rank
        hidden_size = lw_shape[1]
        # Expected output shape : [num_tokens, hidden_size * num_slices]
        assert len(o_shape) == 2
        assert o_shape == (num_tokens, hidden_size * num_slices)

        return {
            'inputs': self.input,
            'lora_b_weights': self.lora_weights_lst,
            'output_tensor': self.output,
            'b_seq_start_loc': self.seq_start_loc,
            'seq_len_tensor': self.seq_lens,
            'lora_indices_tensor': self.prompt_lora_mapping,
            'batches': num_seqs,
            'max_seq_length': max_seq_len,
            'token_nums': num_tokens,
            'offset_start': 0,
            'add_inputs': add_inputs,
        }

    def as_bgmv_shrink_kwargs(self) -> Dict[str, Any]:
        """Build bgmv_shrink keyword arguments (single slice only)."""
        assert len(self.lora_weights_lst) == 1
        self.to_device(self.input.device)

        _, num_tokens, _, _ = self.metadata()
        # Sanity check shapes
        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
            0].shape, self.output.shape
        # Expected input shape [num_tokens, hidden_size]
        assert len(i_shape) == 2
        assert i_shape[0] == num_tokens
        hidden_size = i_shape[1]
        # Expected lora weight shape [num_loras, lora_rank, hidden_size]
        assert len(lw_shape) == 3
        assert lw_shape[2] == hidden_size
        lora_rank = lw_shape[1]
        # Expected output shape [num_tokens, lora_rank]
        assert len(o_shape) == 2
        assert o_shape == (num_tokens, lora_rank)

        return {
            'inputs': self.input,
            'lora_a_weights': self.lora_weights_lst[0],
            'output_tensor': self.output,
            'lora_indices_tensor': self.token_lora_mapping,
            'scaling': 1.0
        }

    def as_bgmv_expand_kwargs(self, add_inputs: bool):
        """Build bgmv_expand keyword arguments (single slice only)."""
        assert len(self.lora_weights_lst) == 1
        self.to_device(self.input.device)

        _, num_tokens, _, _ = self.metadata()
        # Sanity check shapes
        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
            0].shape, self.output.shape
        # Expected input shape [num_tokens, lora_rank]
        assert len(i_shape) == 2
        assert i_shape[0] == num_tokens
        lora_rank = i_shape[1]
        # Expected lora weight shape [num_loras, hidden_size, lora_rank]
        assert len(lw_shape) == 3
        assert lw_shape[2] == lora_rank
        hidden_size = lw_shape[1]
        # Expected output shape [num_tokens, hidden_size]
        assert len(o_shape) == 2
        assert o_shape == (num_tokens, hidden_size)

        return {
            'inputs': self.input,
            'lora_b_weights': self.lora_weights_lst[0],
            'output_tensor': self.output,
            'lora_indices_tensor': self.token_lora_mapping,
            'add_inputs': add_inputs
        }

    def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> Dict[str, Any]:
        """Build one bgmv_expand_slice kwargs dict per slice.

        Returned as ``{'kwargs_list': [...]}`` for the slice-emulating bench
        function.
        """
        _, num_tokens, _, num_slices = self.metadata()
        # Sanity check shapes
        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
            0].shape, self.output.shape
        # Expected input shape [num_slices, num_tokens, lora_rank]
        assert len(i_shape) == 3
        assert i_shape[0] == num_slices
        assert i_shape[1] == num_tokens
        lora_rank = i_shape[2]
        # Expected lora weight shape [num_loras, hidden_size, lora_rank]
        assert len(lw_shape) == 3
        assert lw_shape[2] == lora_rank
        hidden_size = lw_shape[1]
        # Expected output shape [num_tokens, hidden_size * num_slices]
        assert len(o_shape) == 2
        assert o_shape == (num_tokens, hidden_size * num_slices)

        self.to_device(self.input.device)

        kwargs_list = []
        for i in range(num_slices):
            kwargs_list.append({
                'inputs': self.input[i],
                'lora_b_weights': self.lora_weights_lst[i],
                'output_tensor': self.output,
                'lora_indices_tensor': self.token_lora_mapping,
                'slice_offset': i * hidden_size,
                'slice_size': hidden_size,
                'add_inputs': add_inputs,
            })
        return {'kwargs_list': kwargs_list}

    def bench_fn_kwargs(self,
                        op_type: OpType,
                        add_inputs: Optional[bool] = None) -> Dict[str, Any]:
        """Return the keyword arguments for ``op_type.bench_fn()``.

        Raises:
            ValueError: for an unsupported op.
        """
        if op_type.is_shrink_fn():
            assert add_inputs is None
        else:
            assert add_inputs is not None

        if op_type == OpType.SGMV_SHRINK:
            return self.as_sgmv_shrink_kwargs()
        if op_type == OpType.SGMV_EXPAND:
            return self.as_sgmv_expand_kwargs(add_inputs)
        if op_type == OpType.BGMV_SHRINK:
            return self.as_bgmv_shrink_kwargs()
        if op_type == OpType.BGMV_EXPAND:
            return self.as_bgmv_expand_kwargs(add_inputs)
        if op_type == OpType.BGMV_EXPAND_SLICE:
            return self.as_bgmv_expand_slice_kwargs(add_inputs)
        # Report the unrecognised op, not this tensor bundle.
        raise ValueError(f"Unrecognized optype {op_type}")

    def test_correctness(self, op_type: OpType,
                         expand_fn_add_inputs: Optional[bool]) -> bool:
        """
        Test correctness of op_type implementation against a grouped gemm
        reference implementation.
        """
        # Capture the per-prompt metadata before bench_fn_kwargs possibly
        # collapses it for sgmv.
        seq_lens_cpu = self.seq_lens.to(device="cpu")
        prompt_lora_mapping_cpu = self.prompt_lora_mapping.to(device="cpu")

        # Zero first, then clone, so the kernel and the reference start from
        # the same output state (matters when add_inputs=True).
        self.output.zero_()
        ref_output = self.output.clone()

        op_type.bench_fn()(
            **self.bench_fn_kwargs(op_type, expand_fn_add_inputs))

        op_type.run_ref_group_gemm(
            ref_output,
            self.input,
            self.lora_weights_lst,
            seq_lens_cpu=seq_lens_cpu,
            prompt_lora_mapping_cpu=prompt_lora_mapping_cpu,
            scaling=1.0,
            add_inputs=expand_fn_add_inputs)

        rtol, atol = {
            torch.float16: (6e-2, 6e-2),
            torch.bfloat16: (6e-2, 6e-2),
            torch.float32: (1e-2, 1e-2),
        }[self.output.dtype]

        return torch.allclose(ref_output, self.output, rtol=rtol, atol=atol)
|
|
|
|
|
|
|
|
|
|
|
|
def bench_optype(ctx: BenchmarkContext,
                 arg_pool_size: int,
                 op_type: OpType,
                 cuda_graph_nops: Optional[int] = None,
                 expand_fn_add_inputs: Optional[bool] = None,
                 test_correctness: bool = False) -> TMeasurement:
    """Benchmark ``op_type`` over a pool of ``arg_pool_size`` argument sets.

    Shrink ops must be called with ``expand_fn_add_inputs=None`` and expand
    ops with a bool. When ``test_correctness`` is set, every pooled argument
    set is checked against the reference gemm before timing.
    """
    assert arg_pool_size >= 1
    if op_type.is_shrink_fn():
        assert expand_fn_add_inputs is None
    else:
        assert expand_fn_add_inputs is not None

    # BenchmarkContext -> BenchmarkTensors
    pool: List[BenchmarkTensors] = [
        BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)
    ]
    for tensors in pool:
        tensors.sanity_check()

    # Test correctness of our implementation (every entry is checked).
    if test_correctness:
        verdicts = [
            tensors.test_correctness(op_type, expand_fn_add_inputs)
            for tensors in pool
        ]
        assert all(verdicts)

    # BenchmarkTensors -> Dict (kwargs)
    per_entry_kwargs = [
        tensors.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs)
        for tensors in pool
    ]

    # Clear LoRA optimization hash-maps.
    _LORA_A_PTR_DICT.clear()
    _LORA_B_PTR_DICT.clear()
    # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are setup
    kernel = op_type.bench_fn()
    for entry in per_entry_kwargs:
        kernel(**entry)
    torch.cuda.synchronize()

    # Merge into a single kwargs and qualify arguments as ArgPool
    pooled = {name: ArgPool([]) for name in per_entry_kwargs[0]}
    for entry in per_entry_kwargs:
        for name, value in entry.items():
            pooled[name].values.append(value)

    if expand_fn_add_inputs is None:
        describe_args = ""
    else:
        describe_args = f"add_inputs={expand_fn_add_inputs}"
    description = (
        f"{op_type.name}({describe_args}) ({pool[0].io_types()})")

    cuda_graph_params = (CudaGraphBenchParams(cuda_graph_nops)
                         if cuda_graph_nops else None)
    with Bench(cuda_graph_params,
               ctx.bench_label(), ctx.bench_sublabel(op_type), description,
               op_type.bench_fn(), **pooled) as bench:
        measurement = bench.run()
    return measurement
|
|
|
|
|
|
|
|
|
|
|
|
def bench_torch_mm(ctx: BenchmarkContext,
                   arg_pool_size: int,
                   op_type: OpType,
                   cuda_graph_nops: Optional[int] = None) -> TMeasurement:
    """
    Benchmark basic torch.mm as a roofline.

    When all the input tokens have the same LoRA ID, the LoRA kernels are just
    a matmul. This torch.mm benchmark serves as a roofline for that case.

    input op_type is used in determining the m, k, n dimensions for the matmul.
    The n dimension is widened by ``ctx.num_slices`` so that sliced ops are
    compared against a single matmul of the same total output width.
    """
    dtype = ctx.dtype
    m, k, n = op_type.mkn(ctx.batch_size, ctx.seq_length, ctx.hidden_size,
                          ctx.lora_rank)
    # For a fairer comparison.
    n *= ctx.num_slices

    # Get matmul input and output tensors for A x B = C; the per-iteration
    # allocation order (A, then B, then C) is kept.
    lhs_pool, rhs_pool, out_pool = [], [], []
    for _ in range(arg_pool_size):
        lhs_pool.append(torch.rand((m, k), dtype=dtype).to("cuda"))
        rhs_pool.append(torch.rand((n, k), dtype=dtype).to("cuda").t())
        out_pool.append(torch.rand((m, n), dtype=dtype).to("cuda"))

    # Make torch.mm kwargs
    mm_kwargs = {
        'input': ArgPool(lhs_pool),
        'mat2': ArgPool(rhs_pool),
        'out': ArgPool(out_pool),
    }

    tag = dtype_to_str(dtype)
    description = f"single-lora roofline using torch.mm ({tag}x{tag}=>{tag})"

    cuda_graph_params = (CudaGraphBenchParams(cuda_graph_nops)
                         if cuda_graph_nops else None)
    with Bench(cuda_graph_params, ctx.bench_label(),
               ctx.bench_sublabel(op_type), description, torch.mm,
               **mm_kwargs) as bench:
        return bench.run()
|
|
|
|
|
|
|
|
|
|
|
|
# runner
|
|
|
|
def use_cuda_graph_recommendation() -> str:
    """Return advice on benchmarking inside CUDA Graphs (``--cuda-graph-nops``).

    ``run`` prints this text when CUDA Graphs are not enabled.
    """
    return """
Triton kernels have a significant launch overhead when
launched directly via python. This overhead is more noticeable
for small problem sizes. For these cases, it is recommended
to use the script with `--cuda-graph-nops N` to benchmark N
consecutive invocations of the benchmarking operations from
inside a CUDA Graph. Note that the returned measurement is for N
invocations of the operation.
"""
|
|
|
|
|
|
|
|
|
|
|
|
def print_timers(timers: List[TMeasurement],
                 args: Optional[argparse.Namespace] = None):
    """Print a comparison table of ``timers`` followed by explanatory notes.

    Args:
        timers: Measurements to tabulate with ``torch.utils.benchmark.Compare``.
        args: Parsed CLI arguments. When given and ``args.cuda_graph_nops``
            is set, an extra note explains that every reported timing covers
            that many consecutive invocations.

    A note describing the ``torch.mm`` roofline rows is always printed.
    """
    compare = TBenchmark.Compare(timers)
    compare.print()

    if args and args.cuda_graph_nops:
        print(
            f"Note : The timings reported above are for {args.cuda_graph_nops} "
            "consecutive invocations of the benchmarking functions. "
            f"Please divide by {args.cuda_graph_nops} for single invocation "
            "timings.")

    print("Note on Comparison with torch.mm : The torch.mm numbers are "
          "benchmark numbers of a simple matmul emulating the single lora "
          "case. It is provided as a roofline for comparing our LoRA Kernel "
          "implementations. It is expected that the LoRA kernels will be "
          "slower than torch.mm in cases where num_loras is big. But for "
          "small num_loras the goal should be to match the torch.mm numbers.")
|
|
|
|
|
|
|
|
|
|
|
|
def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]):
    """Benchmark all requested op types over ``bench_ctxs`` and report.

    For every context and every ``args.seq_lengths`` entry, the applicable
    ops are selected: decode ops when ``seq_len == 1``, prefill ops
    otherwise. Each op is benchmarked for every slice count it supports.
    Each (op, num_slices) pair first records a ``torch.mm`` roofline and then
    one measurement per expand-fn ``add_inputs`` setting. Shrink ops have a
    single, ``None`` setting.

    A table is printed after each sequence length and again for all
    results. When ``args.output_directory`` is set, every measurement is
    pickled to ``lora_bench-<unix timestamp>.pkl`` in that directory.

    Raises:
        ValueError: if ``args.cuda_graph_nops`` is set but not positive.
    """
    if args.cuda_graph_nops is not None:
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if args.cuda_graph_nops <= 0:
            raise ValueError("--cuda-graph-nops must be a positive integer, "
                             f"got {args.cuda_graph_nops}")
        print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA "
              "Graph")
    else:
        print(f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}")

    timers = []
    for bench_ctx in bench_ctxs:
        for seq_len in args.seq_lengths:
            if seq_len == 1:
                # bench all decode ops
                bench_ops = [op for op in args.op_types if op.is_decode_op()]
            else:
                # bench all prefill ops
                bench_ops = [op for op in args.op_types if op.is_prefill_op()]

            seq_len_timers = []
            for bench_op in bench_ops:
                for num_slices in bench_op.num_slices():
                    _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices(
                        num_slices)
                    # Benchmark torch.mm as a roofline
                    seq_len_timers.append(
                        bench_torch_mm(_ctx, args.arg_pool_size, bench_op,
                                       args.cuda_graph_nops))

                    # Benchmark bench_op. Shrink ops take no add_inputs
                    # argument, hence the single None placeholder.
                    expand_fn_add_inputs = [
                        None
                    ] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs
                    for add_input_arg in expand_fn_add_inputs:
                        seq_len_timers.append(
                            bench_optype(_ctx, args.arg_pool_size, bench_op,
                                         args.cuda_graph_nops, add_input_arg,
                                         args.test_correctness))

            print_timers(seq_len_timers)
            timers.extend(seq_len_timers)

    # Result stdout dump
    print("== All Results ====")
    print_timers(timers, args)

    if args.output_directory:
        # Result file dump
        od = Path(args.output_directory)
        # parents=True: accept nested paths whose parents don't exist yet,
        # so a long benchmark run isn't lost at the very last step.
        od.mkdir(parents=True, exist_ok=True)

        timestamp = int(time.time())
        pkl_file = od / f"lora_bench-{timestamp}.pkl"
        print(f"Writing benchmarks to {pkl_file}")
        with open(pkl_file, "wb") as f:
            pickle.dump(timers, f)
|
|
|
|
|
|
|
|
|
|
|
|
def as_benchmark_contexts(hidden_sizes: List[int], lora_ranks: List[int],
                          args: argparse.Namespace) -> List[BenchmarkContext]:
    """Expand the benchmark parameter grid into ``BenchmarkContext`` objects.

    One context is produced per element of the Cartesian product of
    ``args.batch_sizes``, ``hidden_sizes``, ``lora_ranks``, ``args.num_loras``
    and ``args.sort_by_lora_id``. Batch size varies slowest and
    sort-by-lora-id fastest.

    ``seq_length`` and ``num_slices`` are left as ``None``; they depend on
    the op type being benchmarked and are filled in later. If
    ``args.num_active_loras`` is falsy, every LoRA counts as active.
    """
    grid = product(args.batch_sizes, list(hidden_sizes), lora_ranks,
                   args.num_loras, args.sort_by_lora_id)
    return [
        BenchmarkContext(
            batch_size=bs,
            hidden_size=hs,
            lora_rank=rank,
            num_loras=n_loras,
            num_active_loras=args.num_active_loras or n_loras,
            seq_length=None,  # op-type dependent, set later
            sort_by_lora_id=sort_ids,
            dtype=args.dtype,
            num_slices=None,  # op-type dependent, set later
        ) for bs, hs, rank, n_loras, sort_ids in grid
    ]
|
|
|
|
|
|
|
|
|
|
|
|
def run_list_bench(args: argparse.Namespace):
    """Entry point for ``list_bench``: sweep explicitly listed sizes.

    ``args.hidden_sizes`` and ``args.lora_ranks`` are used verbatim as the
    hidden-size and LoRA-rank axes of the benchmark grid.
    """
    print(args)

    sizes = args.hidden_sizes
    ranks = args.lora_ranks
    print("List bench :\n"
          f" Hidden Sizes {sizes}"
          f" LoRA Ranks {ranks}")

    contexts = as_benchmark_contexts(hidden_sizes=sizes,
                                     lora_ranks=ranks,
                                     args=args)
    run(args, contexts)
|
|
|
|
|
|
|
|
|
|
|
|
def run_range_bench(args: argparse.Namespace):
    """Entry point for ``range_bench``: sweep sizes drawn from numeric ranges.

    Each axis is ``range(start, end + 1, increment)``. ``end`` is therefore
    part of the sweep whenever it can be reached from ``start`` in whole
    increments.
    """
    print(args)

    def _sweep(start: int, end: int, step: int) -> List[int]:
        # range() stops before its bound; the +1 makes ``end`` inclusive.
        return list(range(start, end + 1, step))

    hidden_sizes = _sweep(args.hidden_sizes_start, args.hidden_sizes_end,
                          args.hidden_sizes_increment)
    lora_ranks = _sweep(args.lora_ranks_start, args.lora_ranks_end,
                        args.lora_ranks_increment)

    print("Range bench :\n"
          f" Hidden Sizes {hidden_sizes}"
          f" LoRA Ranks {lora_ranks}")

    contexts = as_benchmark_contexts(hidden_sizes=hidden_sizes,
                                     lora_ranks=lora_ranks,
                                     args=args)
    run(args, contexts)
|
|
|
|
|
|
|
|
|
|
|
|
def run_model_bench(args: argparse.Namespace):
    """Entry point for ``model_bench``: sweep hidden sizes taken from models.

    For every (model, tp_size) pair in ``args.models`` x ``args.tp_sizes``,
    each ``(KN, tp_split_dim)`` entry of ``WEIGHT_SHAPES[model]`` is divided
    by ``tp_size`` along ``tp_split_dim``. Its resulting ``KN[1]`` is
    collected as a hidden size. The de-duplicated, sorted sizes, together
    with ``args.lora_ranks``, form the benchmark grid.
    """
    print(args)

    def hidden_sizes_from_model(model: str, tp_size: int) -> set[int]:
        hidden_sizes = set()
        for KN, tp_split_dim in WEIGHT_SHAPES[model]:
            # Work on a copy: WEIGHT_SHAPES is shared module state. Dividing
            # in place would compound across tp sizes and repeated models
            # (e.g. --tp-sizes 2 4 would end up dividing by 8).
            kn = list(KN)
            kn[tp_split_dim] = kn[tp_split_dim] // tp_size
            hidden_sizes.add(kn[1])
        return hidden_sizes

    # Get all hidden sizes
    unique_sizes: set[int] = set()
    for model_name, tp_size in product(args.models, args.tp_sizes):
        unique_sizes |= hidden_sizes_from_model(model_name, tp_size)
    # Sorted list: deterministic benchmark order, and matches the List[int]
    # parameter expected by as_benchmark_contexts.
    hidden_sizes = sorted(unique_sizes)

    print("Model bench :\n"
          f" Hidden Sizes {hidden_sizes}"
          f" LoRA Ranks {args.lora_ranks}")

    # Get all benchmarking contexts
    bench_contexts: List[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args)

    run(args, bench_contexts)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':

    def to_torch_dtype(dt):
        """Convert a ``--dtype`` CLI string into the matching ``torch.dtype``.

        Raises:
            ValueError: for anything other than 'torch.float16' or
                'torch.bfloat16'; argparse reports it as an invalid value.
        """
        if dt == "torch.float16":
            return torch.float16
        if dt == "torch.bfloat16":
            return torch.bfloat16
        raise ValueError("unsupported dtype")

    def get_bool(s: str) -> bool:
        """Parse a CLI boolean: 'true'/'1' (case-insensitive) are True,
        any other string is False."""
        return s.lower() in ['true', '1']

    def add_common_command_args(p: argparse.ArgumentParser):
        """Register the options shared by every ``*_bench`` subcommand.

        Adjacent string literals in the help texts are concatenated
        verbatim, so every fragment ends with a space where a word break
        is intended.
        """
        p.add_argument(
            "--dtype",
            type=to_torch_dtype,
            required=True,
            help="Available options are ['torch.float16', 'torch.bfloat16']")

        p.add_argument(
            "--arg-pool-size",
            type=int,
            default=32,
            help="Run profiles with a pool of input/output/meta tensors instead "
            "of simply reusing the same tensors for all runs. A bigger arg-pool "
            "mitigates hardware caching effects during benchmarking.")

        p.add_argument(
            "--cuda-graph-nops",
            type=int,
            help=("When set, profiling is done using cudagraph, "
                  "with the given number of operations in a graph. "
                  "Note that the measurement returned is the time "
                  "taken for N consecutive executions of the benchmarking "
                  "functions, where N is the value of this argument."))
        p.add_argument("--num-loras",
                       nargs="+",
                       type=int,
                       default=DEFAULT_NUM_LORAS)
        p.add_argument("--num-active-loras",
                       type=int,
                       default=None,
                       help="Number of active LoRAs. When not set, all LoRAs "
                       "are active.")
        p.add_argument("--sort-by-lora-id",
                       nargs="+",
                       type=get_bool,
                       default=DEFAULT_SORT_BY_LORA_IDS)
        p.add_argument("--op-types",
                       nargs="+",
                       type=OpType.from_str,
                       default=list(OpType))
        p.add_argument('--seq-lengths',
                       nargs="+",
                       type=int,
                       default=DEFAULT_SEQ_LENGTHS)
        p.add_argument("--batch-sizes",
                       nargs="+",
                       type=int,
                       default=DEFAULT_BATCH_SIZES)
        p.add_argument("--expand-fn-add-inputs",
                       nargs="+",
                       type=get_bool,
                       default=DEFAULT_EXPAND_FN_ADD_INPUTS)
        p.add_argument(
            '-o',
            '--output-directory',
            type=str,
            help=("Output directory to store the list of benchmarking "
                  "TMeasurement objects as a pickle file"))

        p.add_argument(
            "--test-correctness",
            action='store_true',
            help=("When enabled, the benchmarking functions are tested "
                  "for correctness before the actual benchmarking"))

    parser = FlexibleArgumentParser(
        description=f"""
Benchmark LoRA kernels:
{use_cuda_graph_recommendation()}

list_bench example:
python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32

model_bench example:
python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32

range_bench example:
python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
""",  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter)

    subparsers = parser.add_subparsers(dest="cmd", required=True)

    # list_bench: explicit lists of hidden sizes and LoRA ranks.
    list_parser = subparsers.add_parser("list_bench")
    list_parser.add_argument("--hidden-sizes",
                             nargs="+",
                             type=int,
                             default=DEFAULT_HIDDEN_SIZES)
    list_parser.add_argument("--lora-ranks",
                             nargs="+",
                             type=int,
                             default=DEFAULT_LORA_RANKS)
    add_common_command_args(list_parser)
    list_parser.set_defaults(func=run_list_bench)

    # range_bench: hidden sizes and LoRA ranks given as start/end/increment.
    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--hidden-sizes-start", type=int, required=True)
    range_parser.add_argument("--hidden-sizes-end", type=int, required=True)
    range_parser.add_argument("--hidden-sizes-increment",
                              type=int,
                              required=True)
    range_parser.add_argument("--lora-ranks-start", type=int, required=True)
    range_parser.add_argument("--lora-ranks-end", type=int, required=True)
    range_parser.add_argument("--lora-ranks-increment",
                              type=int,
                              required=True)
    add_common_command_args(range_parser)
    range_parser.set_defaults(func=run_range_bench)

    # model_bench: hidden sizes derived from model weight shapes.
    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument("--models",
                              nargs="+",
                              type=str,
                              default=DEFAULT_MODELS,
                              choices=WEIGHT_SHAPES.keys())
    model_parser.add_argument("--tp-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_TP_SIZES)
    model_parser.add_argument("--lora-ranks",
                              nargs="+",
                              type=int,
                              default=DEFAULT_LORA_RANKS)
    add_common_command_args(model_parser)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    args.func(args)
|