Fix w8a8 benchmark and add Llama-3-8B (#5562)
commit e2b85cf86a
parent 845a3f26f9
@@ -46,7 +46,7 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
 # impl
 
 
-def pytorch_i8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
+def pytorch_mm_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
                     scale_b: torch.tensor,
                     out_dtype: torch.dtype) -> torch.tensor:
     return torch.mm(a, b)
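The rename is purely semantic but worth the churn: this impl ignores its scale and out_dtype arguments and runs a plain torch.mm, so it is the unscaled baseline shared by the int8 and fp8 paths, not an int8 kernel. As a sketch of how such an impl is driven, assuming bench_fn wraps torch.utils.benchmark.Timer (its internals are outside this diff, so treat the wrapper below as hypothetical):

```python
import torch
import torch.utils.benchmark as TBenchmark


def pytorch_mm_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
                    scale_b: torch.Tensor,
                    out_dtype: torch.dtype) -> torch.Tensor:
    # Scales and out_dtype are accepted only for interface parity with the
    # scaled impls; this is the plain-matmul baseline.
    return torch.mm(a, b)


# Hypothetical stand-in for the benchmark's bench_fn wrapper.
def time_impl(a, b, scale_a, scale_b, out_dtype, label, sub_label, impl,
              description):
    return TBenchmark.Timer(
        stmt="impl(a, b, scale_a, scale_b, out_dtype)",
        globals=dict(a=a, b=b, scale_a=scale_a, scale_b=scale_b,
                     out_dtype=out_dtype, impl=impl),
        label=label, sub_label=sub_label, description=description,
    ).blocked_autorange(min_run_time=1)
```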
@@ -115,7 +115,7 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     timers.append(
         bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
                  b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
-                 torch.bfloat16, label, sub_label, pytorch_i8_impl,
+                 torch.bfloat16, label, sub_label, pytorch_mm_impl,
                  "pytorch_bf16_bf16_bf16_matmul-no-scales"))
 
     # cutlass impl
@@ -136,6 +136,13 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
 
     timers = []
 
+    # pytorch impl w. bf16
+    timers.append(
+        bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
+                 b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
+                 torch.bfloat16, label, sub_label, pytorch_mm_impl,
+                 "pytorch_bf16_bf16_bf16_matmul-no-scales"))
+
     # pytorch impl: bf16 output, without fp8 fast accum
     timers.append(
         bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
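bench_fp8 previously had no unscaled reference, so fp8 numbers could only be compared against each other; this hunk adds the same bf16 torch.mm entry that bench_int8 already had, and per-shape speedup then falls straight out of the results. A possible post-processing helper, assuming the returned timers are torch.utils.benchmark Measurement objects (which expose .median and .task_spec.description):

```python
# Hypothetical helper: report each kernel's speedup over the bf16 baseline
# using the measurements that bench_fp8 returns.
def speedup_over_bf16(measurements):
    medians = {m.task_spec.description: m.median for m in measurements}
    baseline = medians["pytorch_bf16_bf16_bf16_matmul-no-scales"]
    return {desc: baseline / median for desc, median in medians.items()}
```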
@@ -160,14 +167,12 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
 
     # cutlass impl: bf16 output
     timers.append(
-        bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
-                 torch.bfloat16, label, sub_label, cutlass_impl,
-                 "cutlass_fp8_fp8_bf16_scaled_mm"))
+        bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
+                 cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm"))
 
     # cutlass impl: fp16 output
     timers.append(
-        bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
-                 torch.float16, label, sub_label, cutlass_impl,
-                 "cutlass_fp8_fp8_fp16_scaled_mm"))
+        bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
+                 cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm"))
 
     return timers
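This is the actual benchmark fix: the two cutlass entries were moving scale_a and scale_b to the CPU while a, b, and the kernel live on the GPU, so the scaled-mm path was being fed mixed-device inputs. Keeping the scales where they were created matches every other entry. A minimal sketch of device-consistent fp8 inputs, with shapes chosen only for illustration (assumes a torch build with float8_e4m3fn, i.e. torch >= 2.1):

```python
import torch

m, k, n = 16, 4096, 6144
a = torch.randn((m, k), device="cuda").to(torch.float8_e4m3fn)
# Column-major B, as scaled-mm kernels typically expect.
b = torch.randn((n, k), device="cuda").to(torch.float8_e4m3fn).t()

# Per-tensor scales must live on the same device as the fp8 operands;
# the .to(device="cpu") calls removed above broke exactly this invariant.
scale_a = torch.rand(1, device="cuda", dtype=torch.float32)
scale_b = torch.rand(1, device="cuda", dtype=torch.float32)
assert a.device == b.device == scale_a.device == scale_b.device
```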
@@ -22,6 +22,12 @@ WEIGHT_SHAPES = {
         ([4096, 22016], 1),
         ([11008, 4096], 0),
     ],
+    "meta-llama/Llama-3-8b": [
+        ([4096, 6144], 1),
+        ([4096, 4096], 0),
+        ([4096, 28672], 1),
+        ([14336, 4096], 0),
+    ],
     "meta-llama/Llama-2-13b-hf": [
         ([5120, 15360], 1),
         ([5120, 5120], 0),
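The four new [K, N] pairs are the linear layers of Llama-3-8B at tensor-parallel size 1: the fused QKV projection, the attention output projection, the fused gate/up MLP projection, and the MLP down projection; the second tuple element marks which dimension tensor parallelism splits (1 = output/column-parallel, 0 = input/row-parallel). They follow directly from the model config (hidden_size 4096, 32 attention heads, 8 KV heads, head_dim 128, intermediate_size 14336), as this check shows:

```python
# Deriving the benchmarked GEMM shapes from the Llama-3-8B config.
hidden_size = 4096
num_heads, num_kv_heads, head_dim = 32, 8, 128
intermediate_size = 14336

qkv_proj = [hidden_size, (num_heads + 2 * num_kv_heads) * head_dim]
o_proj = [num_heads * head_dim, hidden_size]
gate_up_proj = [hidden_size, 2 * intermediate_size]
down_proj = [intermediate_size, hidden_size]

assert qkv_proj == [4096, 6144]       # GQA shrinks K/V: 4096 + 2 * 8 * 128
assert o_proj == [4096, 4096]
assert gate_up_proj == [4096, 28672]  # SwiGLU fuses gate and up projections
assert down_proj == [14336, 4096]
```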