# SPDX-License-Identifier: Apache-2.0
"""Benchmark the Triton LoRA kernels (SGMV / BGMV shrink and expand).

The script builds random inputs for every requested combination of batch
size, hidden size, LoRA rank, number of LoRAs, ... and times each kernel with
``torch.utils.benchmark``. A plain ``torch.mm`` of the same problem size is
timed alongside as a roofline. See ``--help`` for example invocations.
"""

import argparse
import copy
import json
import pickle
import time
from dataclasses import dataclass
from enum import Enum, auto
from itertools import product
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from utils import ArgPool, Bench, CudaGraphBenchParams
from weight_shapes import WEIGHT_SHAPES

from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_TP_SIZES = [1]
DEFAULT_BATCH_SIZES = [
    1, 16, 32, 64, 128, 192, 256, 320, 384, 448, 512, 640, 768, 896, 1024,
    2048, 3072, 4096, 5120, 6144, 7168, 8192
]
DEFAULT_HIDDEN_SIZES = [1024, 2048, 4096, 8192, 16384]
DEFAULT_LORA_RANKS = [16]
DEFAULT_NUM_LORAS = [1, 2, 3, 4]
DEFAULT_SORT_BY_LORA_IDS = [False, True]
DEFAULT_SEQ_LENGTHS = [1]
DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False]


# Utilities
def dtype_to_str(dtype: torch.dtype) -> str:
    """Return a short label ("f16", "bf16", "f32") for `dtype`.

    Raises:
        ValueError: if `dtype` is not one of the three supported dtypes.
    """
    if dtype == torch.float16:
        return "f16"
    if dtype == torch.bfloat16:
        return "bf16"
    if dtype == torch.float32:
        return "f32"
    raise ValueError(f"Unsupported dtype {dtype}")


def make_rand_lora_weight_tensor(k: int,
                                 n: int,
                                 num_loras: int,
                                 dtype: torch.dtype,
                                 device: str = "cuda") -> torch.Tensor:
    """Return a random tensor of shape (num_loras, n, k) on `device`."""
    # LoRA weights column major
    return torch.rand((num_loras, n, k), dtype=dtype).to(device)


def make_rand_tensors(
    a_shape: Tuple[int],
    b_shape: Tuple[int],
    c_shape: Tuple[int],
    a_dtype: torch.dtype,
    b_dtype: torch.dtype,
    c_dtype: torch.dtype,
    num_slices: int,
    device: str = "cuda",
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
    """
    Make LoRA input/output matrices.

    Returns a random input `A`, a list of `num_slices` random weight tensors
    `Bs` (one per slice) and a zero-initialized output `C`, all on `device`.
    """
    A = torch.rand(a_shape, dtype=a_dtype).to(device)

    # LoRA weights column major
    Bs = [
        torch.rand(b_shape, dtype=b_dtype).to(device)
        for _ in range(num_slices)
    ]

    C = torch.zeros(c_shape, dtype=c_dtype).to(device)
    return A, Bs, C


def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int,
                             sort_by_lora_id: bool,
                             device: str) -> torch.Tensor:
    """
    All prompts are mapped to a Lora ID in range [0, num_active_loras).
    where 0 refers to first lora, 1 refers to second lora and so on.

    When `sort_by_lora_id` is False the IDs are drawn uniformly at random;
    otherwise the prompts are split into contiguous, equally sized groups
    with increasing LoRA IDs (the last ID absorbs any remainder).
    The returned tensor has dtype long and lives on `device`.
    """
    assert num_active_loras > 0

    if not sort_by_lora_id:
        # Honor `device` here as well, like the sorted branch below does.
        return torch.randint(0,
                             num_active_loras, (num_prompts, ),
                             dtype=torch.long,
                             device=device)

    # Divide LoRAs equally and in order.
    part_size = max(num_prompts // num_active_loras, 1)

    lora_id = 0
    prompt_lora_mapping: List[int] = []
    while len(prompt_lora_mapping) < num_prompts:
        prompt_lora_mapping.extend([lora_id] * part_size)
        # Saturate at the last valid LoRA ID.
        if lora_id + 1 < num_active_loras:
            lora_id += 1
    return torch.tensor(prompt_lora_mapping[:num_prompts],
                        dtype=torch.long,
                        device=device)


def make_token_lora_mapping(num_tokens: int, num_prompts: int,
                            prompt_lora_mapping: torch.Tensor,
                            seq_len_tensor: torch.Tensor,
                            device: str) -> torch.Tensor:
    """
    Make token_lora_mapping from prompt_lora_mapping and seq_lens_tensor

    Every token of prompt `i` gets the LoRA ID `prompt_lora_mapping[i]`.
    """
    assert prompt_lora_mapping.shape[0] == num_prompts

    # token to lora index mapping
    token_lora_mapping = [0] * num_tokens
    current_offset = 0
    for b_id in range(num_prompts):
        lora_index = prompt_lora_mapping[b_id].item()
        seq_len = seq_len_tensor[b_id].item()
        s = current_offset
        e = s + seq_len
        token_lora_mapping[s:e] = [lora_index] * (e - s)
        current_offset += seq_len

    return torch.tensor(token_lora_mapping, dtype=torch.long, device=device)


def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor,
                   lora_weights: List[torch.Tensor],
                   seq_lens_cpu: torch.Tensor,
                   prompt_lora_mapping_cpu: torch.Tensor, scaling: float,
                   add_inputs: Optional[bool]) -> None:
    """
    Torch group gemm reference implementation to test correctness of
    benchmarking operations.

    For every sequence, the rows of `input` belonging to it are multiplied
    with the LoRA weight selected by `prompt_lora_mapping_cpu` and scaled by
    `scaling`. The concatenated result is accumulated into `ref_out` when
    `add_inputs` is truthy, and copied into `ref_out` otherwise.
    """
    batches = seq_lens_cpu.size(0)
    out_list = []
    current_offset = 0
    for lora_index, b_length in zip(range(batches), seq_lens_cpu):
        x = input[current_offset:b_length + current_offset, :]
        current_offset += b_length
        w = lora_weights[prompt_lora_mapping_cpu[lora_index]]
        result = torch.nn.functional.linear(x, w)
        result *= scaling
        out_list.append(result)

    cat_result = torch.cat(out_list, dim=0)

    if add_inputs:
        ref_out += cat_result
    else:
        ref_out.copy_(cat_result)


class OpType(Enum):
    """
    LoRA Ops to benchmark and its properties.
    """
    SGMV_SHRINK = auto()
    BGMV_SHRINK = auto()
    SGMV_EXPAND = auto()
    BGMV_EXPAND = auto()
    BGMV_EXPAND_SLICE = auto()

    @staticmethod
    def from_str(s: str) -> "OpType":
        """Parse a case-insensitive op name such as "sgmv_shrink".

        Raises:
            ValueError: if `s` does not name an OpType member.
        """
        wanted = s.lower()
        for op in OpType:
            if op.name.lower() == wanted:
                return op
        raise ValueError(f"Unrecognized str {s} to convert to OpType")

    def is_shrink_fn(self) -> bool:
        return self in [OpType.SGMV_SHRINK, OpType.BGMV_SHRINK]

    def is_expand_fn(self) -> bool:
        # Note: BGMV_EXPAND_SLICE is reported by is_expand_slice_fn() instead.
        return self in [OpType.SGMV_EXPAND, OpType.BGMV_EXPAND]

    def is_prefill_op(self) -> bool:
        return self in [OpType.SGMV_SHRINK, OpType.SGMV_EXPAND]

    def is_decode_op(self) -> bool:
        return self in [
            OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE
        ]

    def is_expand_slice_fn(self) -> bool:
        return self in [OpType.BGMV_EXPAND_SLICE]

    def num_slices(self) -> List[int]:
        """Return the slice counts to benchmark for this op."""
        if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]:
            # SGMV kernels supports slices
            return [1, 2, 3]
        if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]:
            return [1]
        if self in [OpType.BGMV_EXPAND_SLICE]:
            return [2, 3]
        raise ValueError(f"Unrecognized OpType {self}")

    def mkn(self, batch_size: int, seq_length: int, hidden_size: int,
            lora_rank: int) -> Tuple[int, int, int]:
        """Return the (m, k, n) matmul dimensions for this op.

        Shrink ops reduce hidden_size -> lora_rank; expand ops do the
        opposite. `m` is always the total number of tokens.
        """
        num_tokens = batch_size * seq_length
        if self.is_shrink_fn():
            m = num_tokens
            k = hidden_size
            n = lora_rank
        else:
            assert self.is_expand_fn() or self.is_expand_slice_fn()
            m = num_tokens
            k = lora_rank
            n = hidden_size
        return m, k, n

    def matmul_dtypes(
            self, op_dtype: torch.dtype
    ) -> Tuple[torch.dtype, torch.dtype, torch.dtype]:
        """
        return a type, b type and c type for A x B = C
        """
        if self.is_shrink_fn():
            return op_dtype, op_dtype, torch.float32
        else:
            assert self.is_expand_fn() or self.is_expand_slice_fn()
            return torch.float32, op_dtype, op_dtype

    def matmul_shapes(
            self, batch_size: int, seq_length: int, hidden_size: int,
            lora_rank: int, num_loras: int,
            num_slices: int) -> Tuple[Tuple[int], Tuple[int], Tuple[int]]:
        """
        Given num_slices, return the shapes of the A, B, and C matrices
        in A x B = C, for the op_type
        """
        m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)

        b_shape = (num_loras, n, k)  # col-major
        if self == OpType.SGMV_SHRINK:
            # SGMV shrink supports num_slices inherently in the kernel
            return ((m, k), b_shape, (num_slices, m, n))
        if self == OpType.SGMV_EXPAND:
            # SGMV expand supports num_slices inherently in the kernel
            return ((num_slices, m, k), b_shape, (m, n * num_slices))
        if self == OpType.BGMV_SHRINK:
            return ((m, k), b_shape, (m, n))
        if self == OpType.BGMV_EXPAND:
            return ((m, k), b_shape, (m, n))
        if self == OpType.BGMV_EXPAND_SLICE:
            return ((num_slices, m, k), b_shape, (m, n * num_slices))

        raise ValueError(f"Unrecognized op_type {self}")

    def bench_fn(self) -> Callable:
        """Return the callable to benchmark for this op."""

        def emulate_bgmv_expand_slice(kwargs_list: List[Dict[str, Any]]):
            # One bgmv_expand_slice invocation per slice.
            for x in kwargs_list:
                bgmv_expand_slice(**x)

        if self == OpType.SGMV_SHRINK:
            return sgmv_shrink
        if self == OpType.SGMV_EXPAND:
            return sgmv_expand
        if self == OpType.BGMV_SHRINK:
            return bgmv_shrink
        if self == OpType.BGMV_EXPAND:
            return bgmv_expand
        if self == OpType.BGMV_EXPAND_SLICE:
            return emulate_bgmv_expand_slice
        raise ValueError(f"Unrecognized optype {self}")

    def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
                           lora_weights: List[torch.Tensor],
                           **kwargs) -> None:
        """Each benchmark operation expected the input, lora_weights and
        outputs in a slightly different format. Refer to self.matmul_shapes().
        run_ref_group_gemm accounts for those differences in executing a
        reference group gemm for correctness testing.

        The result is written into `output`; `kwargs` are forwarded to
        ref_group_gemm().

        Raises:
            ValueError: if `self` is not a known OpType.
        """
        w_dtype = lora_weights[0].dtype
        num_slices = len(lora_weights)
        if self == OpType.SGMV_SHRINK:
            for slice_idx in range(num_slices):
                ref_group_gemm(ref_out=output[slice_idx, :],
                               input=input,
                               lora_weights=lora_weights[slice_idx],
                               **kwargs)
        elif self in (OpType.SGMV_EXPAND, OpType.BGMV_EXPAND_SLICE):
            # Both ops take a [num_slices, ...] input and write each slice
            # into its own column range of the output.
            hidden_size = lora_weights[0].shape[1]
            for slice_idx in range(num_slices):
                slice_offset = slice_idx * hidden_size
                ref_group_gemm(
                    ref_out=output[:, slice_offset:slice_offset +
                                   hidden_size],
                    input=input[slice_idx].clone().to(dtype=w_dtype),
                    lora_weights=lora_weights[slice_idx],
                    **kwargs)
        elif self == OpType.BGMV_SHRINK:
            assert num_slices == 1
            ref_group_gemm(ref_out=output,
                           input=input,
                           lora_weights=lora_weights[0],
                           **kwargs)
        elif self == OpType.BGMV_EXPAND:
            assert num_slices == 1
            ref_group_gemm(ref_out=output,
                           input=input.clone().to(dtype=w_dtype),
                           lora_weights=lora_weights[0],
                           **kwargs)
        else:
            # Only reached for an op that has no reference implementation.
            raise ValueError(f"Unrecognized optype {self}")


@dataclass
class BenchmarkContext:
    """
    LoRA benchmark context
    """
    batch_size: int
    hidden_size: int
    num_loras: int
    num_active_loras: int
    lora_rank: int
    sort_by_lora_id: bool
    dtype: torch.dtype
    seq_length: Optional[int] = None
    num_slices: Optional[int] = None  # num_slices for slice based ops

    def with_seq_length(self, seq_length: int) -> "BenchmarkContext":
        """Return a shallow copy with `seq_length` set."""
        ctx = copy.copy(self)
        ctx.seq_length = seq_length
        return ctx

    def with_num_slices(self, num_slices: int) -> "BenchmarkContext":
        """Return a shallow copy with `num_slices` set."""
        ctx = copy.copy(self)
        ctx.num_slices = num_slices
        return ctx

    def bench_label(self) -> str:
        return f"lora-{self.dtype}"

    def bench_sublabel(self, op_type: OpType) -> str:
        """Return a JSON description of the problem size for `op_type`."""
        m, k, n = op_type.mkn(self.batch_size, self.seq_length,
                              self.hidden_size, self.lora_rank)
        desc = {
            'bs': self.batch_size,
            'sl': self.seq_length,
            'm': m,
            'k': k,
            'n': n,
            'num_loras': self.num_loras,
            'sort_by_lora': self.sort_by_lora_id,
            'num_slices': self.num_slices,
        }
        return json.dumps(desc)


@dataclass
class BenchmarkTensors:
    """
    Input/Output tensors used for benchmarks
    """
    # matmul tensors
    input: torch.Tensor
    lora_weights_lst: List[torch.Tensor]
    output: torch.Tensor
    # metadata tensors
    seq_lens: torch.Tensor
    seq_start_loc: torch.Tensor
    prompt_lora_mapping: torch.Tensor
    token_lora_mapping: torch.Tensor

    def io_types(self) -> str:
        """Return e.g. "f16xf16=>f32" (input x weight => output dtypes)."""
        return (f"{dtype_to_str(self.input.dtype)}x"
                f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>"
                f"{dtype_to_str(self.output.dtype)}")

    @staticmethod
    def make(ctx: BenchmarkContext,
             op_type: OpType,
             device: str = "cuda") -> "BenchmarkTensors":
        """Create random benchmark tensors for `op_type` described by `ctx`.

        The matmul tensors are created on `device`; the metadata tensors
        are created on the CPU.
        """
        # Make input / output matmul tensors.
        a_shape, b_shape, c_shape = op_type.matmul_shapes(
            ctx.batch_size, ctx.seq_length, ctx.hidden_size, ctx.lora_rank,
            ctx.num_loras, ctx.num_slices)
        a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype)
        input_tensor, lora_weights, output_tensor = \
            make_rand_tensors(a_shape, b_shape, c_shape, a_type, b_type,
                              c_type, num_slices=ctx.num_slices,
                              device=device)

        # Make metadata tensors.
        # Keep the metadata tensors in the CPU for further processing if needed.
        # The tensors get moved to the GPU before benchmarking.
        assert ctx.num_active_loras <= ctx.num_loras
        total_tokens = ctx.batch_size * ctx.seq_length

        # Prepare seq lens tensor
        # (randint over [seq_length, seq_length + 1) => all == seq_length)
        seq_len_tensor = torch.randint(ctx.seq_length, ctx.seq_length + 1,
                                       (ctx.batch_size, ))
        # Prepare seq_start_loc tensor
        seq_start_loc_tensor = torch.cumsum(torch.tensor(
            [0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
                                            dim=0)
        assert total_tokens == seq_len_tensor.sum()
        # Prepare prompt lora indices tensor
        prompt_lora_indices_tensor = make_prompt_lora_mapping(
            ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu")
        # Prepare token lora indices tensor
        token_lora_indices_tensor = make_token_lora_mapping(
            total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
            seq_len_tensor, "cpu")

        return BenchmarkTensors(input_tensor, lora_weights, output_tensor,
                                seq_len_tensor, seq_start_loc_tensor,
                                prompt_lora_indices_tensor,
                                token_lora_indices_tensor)

    def sanity_check(self) -> None:
        """
        Fails asserts when non-conformality is detected.
        """
        num_tokens = self.input.shape[-2]
        # check metadata tensors
        assert torch.sum(self.seq_lens) == num_tokens
        num_seqs = self.seq_lens.shape[0]
        assert self.seq_start_loc.shape[0] == num_seqs
        assert self.prompt_lora_mapping.shape[0] == num_seqs
        assert self.token_lora_mapping.shape[0] == num_tokens

    def to_device(self, device: str) -> None:
        """
        Transfer tensors to device if the tensors aren't already on the device
        """

        def move(tensor: torch.Tensor) -> torch.Tensor:
            # Tensor.to() returns the tensor itself when it already lives on
            # the requested device, so no explicit check is needed.
            return tensor.to(device=device)

        self.input = move(self.input)
        self.output = move(self.output)
        self.seq_lens = move(self.seq_lens)
        self.seq_start_loc = move(self.seq_start_loc)
        self.prompt_lora_mapping = move(self.prompt_lora_mapping)
        self.token_lora_mapping = move(self.token_lora_mapping)
        self.lora_weights_lst = [move(w) for w in self.lora_weights_lst]

    def metadata(self) -> Tuple[int, int, int, int]:
        """
        Return num_seqs, num_tokens, max_seq_len and num_slices
        """
        num_seqs = self.seq_lens.shape[0]
        num_tokens = self.token_lora_mapping.shape[0]
        max_seq_len = torch.max(self.seq_lens).item()
        num_slices = len(self.lora_weights_lst)
        return num_seqs, num_tokens, max_seq_len, num_slices

    def convert_to_sgmv_benchmark_tensors(self) -> None:
        """
        For sgmv punica kernels, when consecutive sequences have the
        same LoRA ID, we just merge them together.
        This happens in punica.py::compute_metadata
        """

        # Collapse seq_lens and seq_start_loc
        _, seq_lens = torch.unique_consecutive(self.token_lora_mapping,
                                               return_counts=True)
        cum_result = torch.cumsum(seq_lens, dim=0)
        seq_start_loc = torch.zeros_like(seq_lens)
        seq_start_loc[1:].copy_(cum_result[:-1])

        # Collapse prompt mapping
        prompt_lora_mapping = torch.unique_consecutive(
            self.prompt_lora_mapping)

        assert torch.sum(seq_lens) == torch.sum(self.seq_lens), \
         f"dont match - new {torch.sum(seq_lens)} vs {torch.sum(self.seq_lens)}"

        self.prompt_lora_mapping = prompt_lora_mapping.to(
            dtype=self.prompt_lora_mapping.dtype)
        self.seq_lens = seq_lens.to(dtype=self.seq_lens.dtype)
        self.seq_start_loc = seq_start_loc.to(dtype=self.seq_start_loc.dtype)

    def as_sgmv_shrink_kwargs(self) -> Dict[str, Any]:
        """Validate shapes and return the kwargs for sgmv_shrink.

        Side effect: collapses the metadata tensors (see
        convert_to_sgmv_benchmark_tensors) and moves all tensors to the
        input's device.
        """
        self.convert_to_sgmv_benchmark_tensors()
        self.sanity_check()
        self.to_device(self.input.device)

        num_seqs, num_tokens, max_seq_len, num_slices = self.metadata()

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
            0].shape, self.output.shape
        # Expected input shape [num_tokens, hidden_size]
        assert len(i_shape) == 2
        assert i_shape[0] == num_tokens
        hidden_size = i_shape[1]
        # Expected lora weight shape [num_loras, lora_rank, hidden_size]
        assert len(lw_shape) == 3
        assert lw_shape[2] == hidden_size
        lora_rank = lw_shape[1]
        # Expected output shape [num_slices, num_tokens, lora_rank]
        assert len(o_shape) == 3
        assert o_shape == (num_slices, num_tokens, lora_rank)

        return {
            'inputs': self.input,
            'lora_a_weights': self.lora_weights_lst,
            'output_tensor': self.output,
            'b_seq_start_loc': self.seq_start_loc,
            'seq_len_tensor': self.seq_lens,
            'lora_indices_tensor': self.prompt_lora_mapping,
            'batches': num_seqs,
            'max_seq_length': max_seq_len,
            'token_nums': num_tokens,
            'scaling': 1.0,
        }

    def as_sgmv_expand_kwargs(self, add_inputs: bool) -> Dict[str, Any]:
        """Validate shapes and return the kwargs for sgmv_expand.

        Side effect: collapses the metadata tensors (see
        convert_to_sgmv_benchmark_tensors) and moves all tensors to the
        input's device.
        """
        self.convert_to_sgmv_benchmark_tensors()
        self.sanity_check()
        self.to_device(self.input.device)

        num_seqs, num_tokens, max_seq_len, num_slices = self.metadata()

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
            0].shape, self.output.shape
        # Expected input shape : [num_slices, num_tokens, lora_rank]
        assert len(i_shape) == 3
        assert i_shape[0] == num_slices
        assert i_shape[1] == num_tokens
        lora_rank = i_shape[2]
        # Expected lora weight shape : [num_lora, hidden_size, lora_rank]
        assert len(lw_shape) == 3
        assert lw_shape[2] == lora_rank
        hidden_size = lw_shape[1]
        # Expected output shape : [num_tokens, hidden_size * num_slices]
        assert len(o_shape) == 2
        assert o_shape == (num_tokens, hidden_size * num_slices)

        return {
            'inputs': self.input,
            'lora_b_weights': self.lora_weights_lst,
            'output_tensor': self.output,
            'b_seq_start_loc': self.seq_start_loc,
            'seq_len_tensor': self.seq_lens,
            'lora_indices_tensor': self.prompt_lora_mapping,
            'batches': num_seqs,
            'max_seq_length': max_seq_len,
            'token_nums': num_tokens,
            'offset_start': 0,
            'add_inputs': add_inputs,
        }

    def as_bgmv_shrink_kwargs(self) -> Dict[str, Any]:
        """Validate shapes and return the kwargs for bgmv_shrink."""
        assert len(self.lora_weights_lst) == 1

        self.to_device(self.input.device)

        _, num_tokens, _, _ = self.metadata()
        # Sanity check shapes
        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
            0].shape, self.output.shape
        # Expected input shape [num_tokens, hidden_size]
        assert len(i_shape) == 2
        assert i_shape[0] == num_tokens
        hidden_size = i_shape[1]
        # Expected lora weight shape [num_loras, lora_rank, hidden_size]
        assert len(lw_shape) == 3
        assert lw_shape[2] == hidden_size
        lora_rank = lw_shape[1]
        # Expected output shape [num_tokens, lora_rank]
        assert len(o_shape) == 2
        assert o_shape == (num_tokens, lora_rank)

        return {
            'inputs': self.input,
            'lora_a_weights': self.lora_weights_lst[0],
            'output_tensor': self.output,
            'lora_indices_tensor': self.token_lora_mapping,
            'scaling': 1.0
        }

    def as_bgmv_expand_kwargs(self, add_inputs: bool) -> Dict[str, Any]:
        """Validate shapes and return the kwargs for bgmv_expand."""
        assert len(self.lora_weights_lst) == 1

        self.to_device(self.input.device)

        _, num_tokens, _, _ = self.metadata()
        # Sanity check shapes
        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
            0].shape, self.output.shape
        # Expected input shape [num_tokens, lora_rank]
        assert len(i_shape) == 2
        assert i_shape[0] == num_tokens
        lora_rank = i_shape[1]
        # Expected lora weight shape [num_loras, hidden_size, lora_rank]
        assert len(lw_shape) == 3
        assert lw_shape[2] == lora_rank
        hidden_size = lw_shape[1]
        # Expected output shape [num_tokens, hidden_size]
        assert len(o_shape) == 2
        assert o_shape == (num_tokens, hidden_size)

        return {
            'inputs': self.input,
            'lora_b_weights': self.lora_weights_lst[0],
            'output_tensor': self.output,
            'lora_indices_tensor': self.token_lora_mapping,
            'add_inputs': add_inputs
        }

    def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> Dict[str, Any]:
        """Validate shapes and return the kwargs for the bgmv_expand_slice
        emulation: a single 'kwargs_list' entry holding one kwargs dict per
        slice.
        """

        _, num_tokens, _, num_slices = self.metadata()
        # Sanity check shapes
        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
            0].shape, self.output.shape
        # Expected input shape [num_slices, num_tokens, lora_rank]
        assert len(i_shape) == 3
        assert i_shape[0] == num_slices
        assert i_shape[1] == num_tokens
        lora_rank = i_shape[2]
        # Expected lora weight shape [num_loras, hidden_size, lora_rank]
        assert len(lw_shape) == 3
        assert lw_shape[2] == lora_rank
        hidden_size = lw_shape[1]
        # Expected output shape [num_tokens, hidden_size * num_slices]
        assert len(o_shape) == 2
        assert o_shape == (num_tokens, hidden_size * num_slices)

        self.to_device(self.input.device)

        kwargs_list = [{
            'inputs': self.input[i],
            'lora_b_weights': self.lora_weights_lst[i],
            'output_tensor': self.output,
            'lora_indices_tensor': self.token_lora_mapping,
            'slice_offset': i * hidden_size,
            'slice_size': hidden_size,
            'add_inputs': add_inputs,
        } for i in range(num_slices)]
        return {'kwargs_list': kwargs_list}

    def bench_fn_kwargs(self,
                        op_type: OpType,
                        add_inputs: Optional[bool] = None) -> Dict[str, Any]:
        """Return the kwargs to call `op_type.bench_fn()` with.

        `add_inputs` must be None for shrink ops and non-None otherwise.

        Raises:
            ValueError: if `op_type` is not a known OpType.
        """
        if op_type.is_shrink_fn():
            assert add_inputs is None
        else:
            assert add_inputs is not None

        if op_type == OpType.SGMV_SHRINK:
            return self.as_sgmv_shrink_kwargs()
        if op_type == OpType.SGMV_EXPAND:
            return self.as_sgmv_expand_kwargs(add_inputs)
        if op_type == OpType.BGMV_SHRINK:
            return self.as_bgmv_shrink_kwargs()
        if op_type == OpType.BGMV_EXPAND:
            return self.as_bgmv_expand_kwargs(add_inputs)
        if op_type == OpType.BGMV_EXPAND_SLICE:
            return self.as_bgmv_expand_slice_kwargs(add_inputs)
        raise ValueError(f"Unrecognized optype {op_type}")

    def test_correctness(self, op_type: OpType,
                         expand_fn_add_inputs: Optional[bool]) -> bool:
        """
        Test correctness of op_type implementation against a grouped gemm
        reference implementation.

        Returns True when the kernel output matches the reference within the
        dtype-dependent tolerance.
        """
        # Capture the metadata first: building the kwargs below may collapse
        # seq_lens / prompt_lora_mapping for the SGMV ops.
        seq_lens_cpu = self.seq_lens.to(device="cpu")
        prompt_lora_mapping_cpu = self.prompt_lora_mapping.to(device="cpu")
        ref_output = self.output.clone()

        self.output.zero_()
        op_type.bench_fn()(
            **self.bench_fn_kwargs(op_type, expand_fn_add_inputs))

        op_type.run_ref_group_gemm(
            ref_output,
            self.input,
            self.lora_weights_lst,
            seq_lens_cpu=seq_lens_cpu,
            prompt_lora_mapping_cpu=prompt_lora_mapping_cpu,
            scaling=1.0,
            add_inputs=expand_fn_add_inputs)

        rtol, atol = {
            torch.float16: (6e-2, 6e-2),
            torch.bfloat16: (6e-2, 6e-2),
            torch.float32: (1e-2, 1e-2),
        }[self.output.dtype]

        return torch.allclose(ref_output, self.output, rtol=rtol, atol=atol)


def bench_optype(ctx: BenchmarkContext,
                 arg_pool_size: int,
                 op_type: OpType,
                 cuda_graph_nops: Optional[int] = None,
                 expand_fn_add_inputs: Optional[bool] = None,
                 test_correctness: bool = False) -> TMeasurement:
    """Benchmark `op_type` for the problem described by `ctx`.

    `arg_pool_size` independent sets of tensors are created and handed to the
    benchmark as ArgPools. When `cuda_graph_nops` is set, it is passed to the
    benchmark via CudaGraphBenchParams. When `test_correctness` is set, every
    tensor set is first verified against the reference implementation.
    """

    assert arg_pool_size >= 1
    if op_type.is_shrink_fn():
        assert expand_fn_add_inputs is None
    else:
        assert expand_fn_add_inputs is not None

    # BenchmarkContext -> BenchmarkTensors
    bench_tensors: List[BenchmarkTensors] = \
        [BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)]
    for bt in bench_tensors:
        bt.sanity_check()

    # Test correctness of our implementation.
    if test_correctness:
        assert all([
            bt.test_correctness(op_type, expand_fn_add_inputs)
            for bt in bench_tensors
        ])

    # BenchmarkTensors -> Dict (kwargs)
    kwargs_list = [
        bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs)
        for bt in bench_tensors
    ]

    # Clear LoRA optimization hash-maps.
    _LORA_A_PTR_DICT.clear()
    _LORA_B_PTR_DICT.clear()
    # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are setup
    for kwargs in kwargs_list:
        op_type.bench_fn()(**kwargs)
    torch.cuda.synchronize()

    # Merge into a single kwargs and qualify arguments as ArgPool
    kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
    for _kwargs in kwargs_list:
        for k, v in _kwargs.items():
            kwargs[k].values.append(v)

    describe_args = (f"add_inputs={expand_fn_add_inputs}"
                     if expand_fn_add_inputs is not None else "")
    description = (
        f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})")

    cuda_graph_params = None
    if cuda_graph_nops:
        cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)
    timer = None
    with Bench(cuda_graph_params,
               ctx.bench_label(), ctx.bench_sublabel(op_type), description,
               op_type.bench_fn(), **kwargs) as bench:
        timer = bench.run()
    return timer


def bench_torch_mm(ctx: BenchmarkContext,
                   arg_pool_size: int,
                   op_type: OpType,
                   cuda_graph_nops: Optional[int] = None) -> TMeasurement:
    """
    Benchmark basic torch.mm as a roofline.

    When all the input tokens have the same LoRA ID, the LoRA kernels are just
    a matmul. This torch.mm benchmark serves as a roofline for that case.

    input op_type is used in determining the m, k, n dimensions for the matmul.
    """

    batch_size, hidden_size, lora_rank, seq_length, dtype = (ctx.batch_size,
                                                             ctx.hidden_size,
                                                             ctx.lora_rank,
                                                             ctx.seq_length,
                                                             ctx.dtype)

    m, k, n = op_type.mkn(batch_size, seq_length, hidden_size, lora_rank)
    # For a fairer comparison.
    n = n * ctx.num_slices

    # Get matmul input and output tensors for A x B = C
    As, Bs, Cs = [], [], []
    for _ in range(arg_pool_size):
        As.append(torch.rand((m, k), dtype=dtype).to("cuda"))
        Bs.append(torch.rand((n, k), dtype=dtype).to("cuda").t())
        Cs.append(torch.rand((m, n), dtype=dtype).to("cuda"))

    # Make torch.mm kwargs
    mm_kwargs = {'input': ArgPool(As), 'mat2': ArgPool(Bs), 'out': ArgPool(Cs)}

    description = (
        f"single-lora roofline using torch.mm ({dtype_to_str(dtype)}"
        f"x{dtype_to_str(dtype)}"
        f"=>{dtype_to_str(dtype)})")
    cuda_graph_params = None
    if cuda_graph_nops:
        cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)
    with Bench(cuda_graph_params, ctx.bench_label(),
               ctx.bench_sublabel(op_type), description, torch.mm,
               **mm_kwargs) as bench:
        return bench.run()


# runner
def use_cuda_graph_recommendation() -> str:
    """Return the user-facing note recommending `--cuda-graph-nops`."""
    return """
            Triton kernels have a significant launch overhead when
            launched directly via python. This overhead is more noticeable
            for small problem sizes. For these cases, it is recommended
            to use the script with `--cuda-graph-nops N` to benchmark N
            consecutive invocations of the benchmarking operations from
            inside a CUDA Graph. Note that the returned measurement is for N
            invocations of the operation.
            """


def print_timers(timers: List[TMeasurement],
                 args: Optional[argparse.Namespace] = None):
    """Print a comparison table of `timers`, followed by explanatory notes.

    When `args` is given and CUDA graphs are enabled, a reminder that the
    timings cover `args.cuda_graph_nops` invocations is printed as well.
    """
    compare = TBenchmark.Compare(timers)
    compare.print()

    if args and args.cuda_graph_nops:
        print(
            f"Note : The timings reported above is for {args.cuda_graph_nops} "
            "consecutive invocations of the benchmarking functions. "
            f"Please divide by {args.cuda_graph_nops} for single invocation "
            "timings.")

    print("Note on Comparison with torch.mm : The torch.mm numbers are "
          "benchmark numbers of a simple matmul emulating the single lora "
          "case. It is provided as a roofline for comparing our LoRA Kernel "
          "implementations. It is expected that the LoRA kernels will be "
          "slower than torch.mm in cases where num_loras is big. But for "
          "small num_loras the goal should be to match the torch.mm numbers.")


def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]):
    """Run all requested benchmarks for every context and report results.

    For sequence length 1 the decode ops are benchmarked, otherwise the
    prefill ops. Results are printed and, when `args.output_directory` is
    set, pickled into that directory.
    """

    if args.cuda_graph_nops is not None:
        assert args.cuda_graph_nops > 0
        print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA "
              "Graph")
    else:
        print(f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}")

    timers = []
    for bench_ctx in bench_ctxs:
        for seq_len in args.seq_lengths:
            bench_ops: List[OpType] = []
            if seq_len == 1:
                # bench all decode ops
                bench_ops = [op for op in args.op_types if op.is_decode_op()]
            else:
                # bench all prefill ops
                bench_ops = [op for op in args.op_types if op.is_prefill_op()]

            seq_len_timers = []
            for bench_op in bench_ops:
                for num_slices in bench_op.num_slices():
                    _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices(
                        num_slices)
                    # Benchmark torch.mm as a roofline
                    seq_len_timers.append(
                        bench_torch_mm(_ctx, args.arg_pool_size, bench_op,
                                       args.cuda_graph_nops))

                    # Benchmark bench_op
                    expand_fn_add_inputs = [
                        None
                    ] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs
                    for add_input_arg in expand_fn_add_inputs:
                        seq_len_timers.append(
                            bench_optype(_ctx, args.arg_pool_size, bench_op,
                                         args.cuda_graph_nops, add_input_arg,
                                         args.test_correctness))

            print_timers(seq_len_timers)
            timers.extend(seq_len_timers)

    # Result stdout dump
    print("== All Results ====")
    print_timers(timers, args)

    if args.output_directory:
        # Result file dump
        od = Path(args.output_directory)
        # Also create missing parent directories.
        od.mkdir(parents=True, exist_ok=True)

        timestamp = int(time.time())
        pkl_file = od / f"lora_bench-{timestamp}.pkl"
        print(f"Writing benchmarks to {pkl_file}")
        with open(pkl_file, "wb") as f:
            pickle.dump(timers, f)


def as_benchmark_contexts(hidden_sizes: List[int], lora_ranks: List[int],
                          args: argparse.Namespace) -> List[BenchmarkContext]:
    """Return one BenchmarkContext per combination of batch size, hidden
    size, LoRA rank, number of LoRAs and sort-by-lora-id setting.

    `seq_length` and `num_slices` are left unset; they are filled in later
    based on the OpType being benchmarked.
    """

    ctxs: List[BenchmarkContext] = []
    for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product(  # noqa
            args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras,
            args.sort_by_lora_id):
        ctxs.append(
            BenchmarkContext(
                batch_size=batch_size,
                hidden_size=hidden_size,
                lora_rank=lora_rank,
                num_loras=num_loras,
                num_active_loras=args.num_active_loras
                if args.num_active_loras else num_loras,
                # To be filled based on the OpType to benchmark
                seq_length=None,
                sort_by_lora_id=sort_by_lora_id,
                dtype=args.dtype,
                # To be filled based on the OpType to benchmark
                num_slices=None))

    return ctxs


def run_list_bench(args: argparse.Namespace):
    """Benchmark the explicitly listed hidden sizes and LoRA ranks."""
    print(args)

    print("List bench :\n"
          f" Hidden Sizes {args.hidden_sizes}"
          f" LoRA Ranks {args.lora_ranks}")

    # Get all benchmarking contexts
    bench_contexts: List[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args)

    run(args, bench_contexts)


def run_range_bench(args: argparse.Namespace):
    """Benchmark hidden sizes and LoRA ranks given as inclusive ranges."""
    print(args)

    hidden_sizes = list(
        range(args.hidden_sizes_start, args.hidden_sizes_end + 1,
              args.hidden_sizes_increment))
    lora_ranks = list(
        range(args.lora_ranks_start, args.lora_ranks_end + 1,
              args.lora_ranks_increment))

    print("Range bench :\n"
          f" Hidden Sizes {hidden_sizes}"
          f" LoRA Ranks {lora_ranks}")

    # Get all benchmarking contexts
    bench_contexts: List[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args)

    run(args, bench_contexts)


def run_model_bench(args: argparse.Namespace):
    """Benchmark the hidden sizes derived from the selected models' weight
    shapes and tensor-parallel sizes."""
    print(args)

    def hidden_sizes_from_model(model: str, tp_size: int) -> set[int]:
        hidden_sizes = set()
        for KN, tp_split_dim in WEIGHT_SHAPES[model]:
            # Work on a copy: WEIGHT_SHAPES is shared and is visited once
            # per tp_size, so dividing in place would compound.
            kn = list(KN)
            kn[tp_split_dim] = kn[tp_split_dim] // tp_size
            hidden_sizes.add(kn[1])
        return hidden_sizes

    # Get all hidden sizes
    hidden_sizes: set[int] = set()
    for model_name, tp_size in product(args.models, args.tp_sizes):
        hidden_sizes = hidden_sizes.union(
            hidden_sizes_from_model(model_name, tp_size))

    print("Model bench :\n"
          f" Hidden Sizes {hidden_sizes}"
          f" LoRA Ranks {args.lora_ranks}")

    # Get all benchmarking contexts
    bench_contexts: List[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args)

    run(args, bench_contexts)


if __name__ == '__main__':

    def to_torch_dtype(dt):
        if dt == "torch.float16":
            return torch.float16
        if dt == "torch.bfloat16":
            return torch.bfloat16
        raise ValueError("unsupported dtype")

    def get_bool(s: str) -> bool:
        return s.lower() in ['true', '1']

    def add_common_command_args(p: argparse.ArgumentParser):
        """Register the options shared by all sub-commands on `p`."""
        p.add_argument(
            "--dtype",
            type=to_torch_dtype,
            required=True,
            help="Available options are ['torch.float16', 'torch.bfloat16']")

        p.add_argument(
            "--arg-pool-size",
            type=int,
            default=32,
            help="Run profiles with a pool of input/output/meta tensors instead "
            "of simply reusing the same tensors for all runs. A bigger arg-pool "
            "mitigates hardware caching effects during benchmarking.")

        p.add_argument(
            "--cuda-graph-nops",
            type=int,
            help=("when set profiling is done using cudagraph, "
                  "with the given number of operations in a graph. "
                  "Note that the measurement returned is the time "
                  "taken for N consecutive executions of the benchmarking "
                  "functions, where N is the value of this argument."))
        p.add_argument("--num-loras",
                       nargs="+",
                       type=int,
                       default=DEFAULT_NUM_LORAS)
        p.add_argument("--num-active-loras",
                       type=int,
                       default=None,
                       help="Active LoRAs. When None, all LoRAs are active")
        p.add_argument("--sort-by-lora-id",
                       nargs="+",
                       type=get_bool,
                       default=DEFAULT_SORT_BY_LORA_IDS)
        p.add_argument("--op-types",
                       nargs="+",
                       type=OpType.from_str,
                       default=list(OpType))
        p.add_argument('--seq-lengths',
                       nargs="+",
                       type=int,
                       default=DEFAULT_SEQ_LENGTHS)
        p.add_argument("--batch-sizes",
                       nargs="+",
                       type=int,
                       default=DEFAULT_BATCH_SIZES)
        p.add_argument("--expand-fn-add-inputs",
                       nargs="+",
                       type=get_bool,
                       default=DEFAULT_EXPAND_FN_ADD_INPUTS)
        p.add_argument(
            '-o',
            '--output-directory',
            type=str,
            help=("Output directory to store the list of benchmarking "
                  "TMeasurement objects as a pickle file"))

        p.add_argument(
            "--test-correctness",
            action='store_true',
            help=("When enabled, the benchmarking functions are tested "
                  "for correctness before the actual benchmarking"))

    parser = FlexibleArgumentParser(
        description=f"""
Benchmark LoRA kernels:
    {use_cuda_graph_recommendation()}

    list_bench example:
        python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32

    model_bench example:
        python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32

    range_bench example:
        python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter)

    subparsers = parser.add_subparsers(dest="cmd", required=True)

    list_parser = subparsers.add_parser("list_bench")
    list_parser.add_argument("--hidden-sizes",
                             nargs="+",
                             type=int,
                             default=DEFAULT_HIDDEN_SIZES)
    list_parser.add_argument("--lora-ranks",
                             nargs="+",
                             type=int,
                             default=DEFAULT_LORA_RANKS)
    add_common_command_args(list_parser)
    list_parser.set_defaults(func=run_list_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--hidden-sizes-start", type=int, required=True)
    range_parser.add_argument("--hidden-sizes-end", type=int, required=True)
    range_parser.add_argument("--hidden-sizes-increment",
                              type=int,
                              required=True)
    range_parser.add_argument("--lora-ranks-start", type=int, required=True)
    range_parser.add_argument("--lora-ranks-end", type=int, required=True)
    range_parser.add_argument("--lora-ranks-increment",
                              type=int,
                              required=True)
    add_common_command_args(range_parser)
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument("--models",
                              nargs="+",
                              type=str,
                              default=DEFAULT_MODELS,
                              choices=WEIGHT_SHAPES.keys())
    model_parser.add_argument("--tp-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_TP_SIZES)
    model_parser.add_argument("--lora-ranks",
                              nargs="+",
                              type=int,
                              default=DEFAULT_LORA_RANKS)
    add_common_command_args(model_parser)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    args.func(args)