[Bugfix] Fix w8a8 benchmarks for int8 case (#5643)
This commit is contained in:
parent
b23ce92032
commit
6820724e51
@ -120,9 +120,8 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
|||||||
|
|
||||||
# cutlass impl
|
# cutlass impl
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
|
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
|
||||||
torch.bfloat16, label, sub_label, cutlass_impl,
|
cutlass_impl, "cutlass_i8_i8_bf16_scaled_mm"))
|
||||||
"cutlass_i8_i8_bf16_scaled_mm"))
|
|
||||||
|
|
||||||
return timers
|
return timers
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user