[misc] Add LoRA kernel micro benchmarks (#11579)
commit 5fd24ec02e
parent 874f7c292a
benchmarks/kernels/benchmark_lora.py (new file, 1147 lines)
File diff suppressed because it is too large.
benchmarks/kernels/utils.py (new file, 210 lines)
@@ -0,0 +1,210 @@
import dataclasses
from typing import Any, Callable, Iterable, Optional

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement


@dataclasses.dataclass
class CudaGraphBenchParams:
    # Number of benchmarked op invocations captured into a single CUDA graph
    num_ops_in_cuda_graph: int


@dataclasses.dataclass
class ArgPool:
    """
    When some argument of the benchmarking function is annotated with this
    type, the benchmarking class (Bench) will collapse the argument by
    picking a single value from the given list of values during function
    invocation. For every invocation during a benchmarking run, it will
    choose a different value from the list.
    """
    values: Iterable[Any]

    def __getitem__(self, index):
        return self.values[index]


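# Illustrative sketch, not part of the original file: a0, a1 and b are
# hypothetical input tensors. Constructing
#     Bench(None, 'mm', 'm=16', 'torch.mm', torch.mm, ArgPool([a0, a1]), b)
# makes the benchmark alternate between torch.mm(a0, b) and
# torch.mm(a1, b) on successive invocations during the run.

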
class Bench:

    class ArgsIterator:

        def __init__(self, args_list, kwargs_list):
            assert len(args_list) == len(kwargs_list)
            self.args_list = args_list
            self.kwargs_list = kwargs_list
            self.n = len(self.args_list)
            self.idx = 0

        def __next__(self):
            # Note: because this method body contains a yield, calling it
            # returns an infinite generator that cycles through the
            # (args, kwargs) pairs; callers advance it with next().
            while True:
                yield (self.args_list[self.idx], self.kwargs_list[self.idx])
                self.idx += 1
                self.idx = self.idx % self.n

        def reset(self):
            self.idx = 0

        @property
        def n_args(self):
            return self.n

    def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams],
                 label: str, sub_label: str, description: str, fn: Callable,
                 *args, **kwargs):

        self.cuda_graph_params = cuda_graph_params
        self.use_cuda_graph = self.cuda_graph_params is not None
        self.label = label
        self.sub_label = sub_label
        self.description = description
        self.fn = fn

        # Process args
        self._args = args
        self._kwargs = kwargs
        self.args_list, self.kwargs_list = self.collapse_argpool(
            *args, **kwargs)
        self.args_iterator = self.ArgsIterator(self.args_list,
                                               self.kwargs_list)

        # Cudagraph runner
        self.g = None
        if self.use_cuda_graph:
            self.g = self.get_cuda_graph_runner()

        # benchmark run params
        self.min_run_time = 1

    def collapse_argpool(self, *args, **kwargs):
        argpool_args = [arg for arg in args if isinstance(arg, ArgPool)] + [
            arg for arg in kwargs.values() if isinstance(arg, ArgPool)
        ]
        if len(argpool_args) == 0:
            return [args], [kwargs]

        # Make sure all argpools are of the same size
        argpool_size = len(argpool_args[0].values)
        assert all([argpool_size == len(arg.values) for arg in argpool_args])

        # Create copies of the args
        args_list = []
        kwargs_list = []
        for _ in range(argpool_size):
            args_list.append(args)
            kwargs_list.append(kwargs.copy())

        for i in range(argpool_size):
            # Collapse args; just pick the ith value
            args_list[i] = tuple([
                arg[i] if isinstance(arg, ArgPool) else arg
                for arg in args_list[i]
            ])

            # Collapse kwargs
            kwargs_i = kwargs_list[i]
            arg_pool_keys = [
                k for k, v in kwargs_i.items() if isinstance(v, ArgPool)
            ]
            for k in arg_pool_keys:
                # Again, just pick the ith value
                kwargs_i[k] = kwargs_i[k][i]
            kwargs_list[i] = kwargs_i

        return args_list, kwargs_list

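    # Illustrative sketch, not part of the original file: with x a plain
    # tensor and w0/w1 hypothetical pool entries,
    #     collapse_argpool(x, w=ArgPool([w0, w1]))
    # returns ([(x,), (x,)], [{'w': w0}, {'w': w1}]).
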
    def get_cuda_graph_runner(self):
        assert self.use_cuda_graph
        assert self.args_iterator is not None

        num_graph_ops = self.cuda_graph_params.num_ops_in_cuda_graph

        # Warmup
        args_it = self.args_iterator.__next__()
        for _ in range(2):
            args, kwargs = next(args_it)
            self.fn(*args, **kwargs)

        # Capture num_graph_ops invocations of fn into a single CUDA graph,
        # on a side stream.
        self.args_iterator.reset()
        args_it = self.args_iterator.__next__()
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                for _ in range(num_graph_ops):
                    args, kwargs = next(args_it)
                    self.fn(*args, **kwargs)
        return g

    def run_cudagraph(self) -> TMeasurement:
        assert self.use_cuda_graph
        globals = {'g': self.g}

        return TBenchmark.Timer(
            stmt="g.replay()",
            globals=globals,
            label=(
                f"{self.label}"
                f" | cugraph {self.cuda_graph_params.num_ops_in_cuda_graph} ops"
            ),
            sub_label=self.sub_label,
            description=self.description,
        ).blocked_autorange(min_run_time=self.min_run_time)

    def run_eager(self) -> TMeasurement:
        setup = None
        stmt = None
        globals = None

        has_arg_pool = self.args_iterator.n_args > 1
        if has_arg_pool:
            setup = '''
                    args_iterator.reset()
                    args_it = args_iterator.__next__()
                    '''
            stmt = '''
                    args, kwargs = next(args_it)
                    fn(*args, **kwargs)
                    '''
            globals = {'fn': self.fn, 'args_iterator': self.args_iterator}
        else:
            # No arg pool. Just use the args and kwargs directly
            self.args_iterator.reset()
            args_it = self.args_iterator.__next__()
            args, kwargs = next(args_it)

            setup = ""
            stmt = '''
                    fn(*args, **kwargs)
                    '''
            globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs}

        return TBenchmark.Timer(
            stmt=stmt,
            setup=setup,
            globals=globals,
            label=self.label,
            sub_label=self.sub_label,
            description=self.description,
        ).blocked_autorange(min_run_time=self.min_run_time)

    def run(self) -> TMeasurement:
        timer = None
        if self.use_cuda_graph:  # noqa: SIM108
            timer = self.run_cudagraph()
        else:
            timer = self.run_eager()
        if not timer.meets_confidence() or timer.has_warnings:
            print("Doesn't meet confidence - re-running bench ...")
            return self.run()
        return timer

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type:
            print(f"exc type {exc_type}")
            print(f"exc value {exc_value}")
            print(f"exc traceback {traceback}")