[Bugfix] Correctly call cudaProfilerStop in benchmarks script (#14183)
Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
parent ad60bbb2b2
commit c34eeec58d
@@ -40,7 +40,7 @@ def main(num_tokens: int,
 
         end_time = time.perf_counter()
         if profile:
-            torch.cuda.cudart().cudaProfilerStart()
+            torch.cuda.cudart().cudaProfilerStop()
         return (end_time - start_time) / num_iters
 
     # Warmup.
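For context, a minimal sketch of the profile-guarded timing helper this kind of hunk patches. The helper name run_cuda_benchmark and the matmul workload are illustrative assumptions, not taken from this diff; the point is the pairing: cudaProfilerStart opens the profiled range before the timed loop, and the call after the loop should be cudaProfilerStop so the range is closed, rather than a second cudaProfilerStart as in the pre-fix code.

# Sketch only: shows the intended start/stop pairing, not the exact file contents.
import time

import torch


def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
    # Stand-in workload; the real benchmark scripts launch the kernel under test here.
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")

    torch.cuda.synchronize()
    if profile:
        # Open the profiled range before the timed region.
        torch.cuda.cudart().cudaProfilerStart()
    start_time = time.perf_counter()

    for _ in range(num_iters):
        torch.matmul(a, b)
    torch.cuda.synchronize()

    end_time = time.perf_counter()
    if profile:
        # Close the range; before this fix the scripts issued cudaProfilerStart() here again.
        torch.cuda.cudart().cudaProfilerStop()
    return (end_time - start_time) / num_iters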
@@ -153,7 +153,6 @@ def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor,
         result = torch.nn.functional.linear(x, w)
         result *= scaling
         out_list.append(result)
-    torch.cat(out_list, dim=0)
 
     cat_result = torch.cat(out_list, dim=0)
 
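The removed line was a torch.cat whose result was never assigned, so it only allocated a tensor and discarded it; the binding kept below it is the one the reference path actually uses. A small illustration (shapes are arbitrary, chosen only for the example):

import torch

out_list = [torch.ones(2, 3), torch.zeros(2, 3)]

torch.cat(out_list, dim=0)               # allocates a (4, 3) tensor and immediately discards it
cat_result = torch.cat(out_list, dim=0)  # this result is the one used afterwards
print(cat_result.shape)                  # torch.Size([4, 3])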
@@ -45,7 +45,6 @@ def terse_type_name(dt):
         torch.float16: "fp16",
         torch.int8: "int8",
         torch.float8_e4m3fn: "fp8",
-        torch.bfloat16: "bf16",
         torch.float: "float",
         torch.int: "int",
     }[dt]
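The dropped entry appears to duplicate a bfloat16 key defined earlier in the same dict literal (outside the context lines shown here); that is an assumption about the surrounding code, but if it holds, removing it changes nothing, since the last occurrence of a repeated key wins in a Python dict literal:

import torch

# Duplicate keys are legal in a dict literal, but only the last value is kept,
# so this mapping has two entries, not three.
names = {
    torch.bfloat16: "bf16",
    torch.float16: "fp16",
    torch.bfloat16: "bf16",
}
print(len(names))             # 2
print(names[torch.bfloat16])  # bf16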
@@ -259,7 +258,7 @@ def machete_create_bench_fn(bt: BenchmarkTensors,
 
     return lambda: ops.machete_mm(
         a=bt.a,
-        b_q=bt.w_q,
+        b_q=w_q,
         b_type=bt.wtype,
         b_group_scales=bt.w_g_s,
         b_group_zeros=w_g_zp,
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import argparse
+import json
 import time
 from contextlib import nullcontext
 from datetime import datetime
@@ -176,7 +176,7 @@ def main(
 
         end_time = time.perf_counter()
         if profile:
-            torch.cuda.cudart().cudaProfilerStart()
+            torch.cuda.cudart().cudaProfilerStop()
         return (end_time - start_time) / num_iters
 
     # Warmup.
@@ -40,7 +40,7 @@ def main(num_tokens: int,
 
         end_time = time.perf_counter()
         if profile:
-            torch.cuda.cudart().cudaProfilerStart()
+            torch.cuda.cudart().cudaProfilerStop()
         return (end_time - start_time) / num_iters
 
     # Warmup.