[Bugfix] Fix w8a8 benchmarks for int8 case (#5643)
This commit is contained in:
parent
b23ce92032
commit
6820724e51
@ -120,9 +120,8 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
|
||||
# cutlass impl
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
|
||||
torch.bfloat16, label, sub_label, cutlass_impl,
|
||||
"cutlass_i8_i8_bf16_scaled_mm"))
|
||||
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
|
||||
cutlass_impl, "cutlass_i8_i8_bf16_scaled_mm"))
|
||||
|
||||
return timers
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user