[Kernel] Update cutlass_scaled_mm to support 2d group (blockwise) scaling (#11868)

commit 9798b2fb00 (parent 4078052f09)
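For context, "2D group (blockwise) scaling" means the FP8 GEMM no longer uses a single per-tensor scale: activations carry one fp32 scale per (row, 128-wide K group) and weights one scale per 128x128 block, so the kernel must apply a scale that varies along both M/N and K. The sketch below is illustrative only and is not vLLM code; the 128x128 group size and all names are taken from the block_scale shapes used in the benchmark changes further down.

# Illustrative sketch: build per-128x128-block scales for a K x N weight,
# matching the block_scale_b shape (k // 128, n // 128) used in the benchmark.
import torch

BLOCK = 128
K, N = 512, 384
w = torch.randn(K, N, dtype=torch.float32)

# View the weight as (K/BLOCK, BLOCK, N/BLOCK, BLOCK) tiles; one scale per tile.
tiles = w.reshape(K // BLOCK, BLOCK, N // BLOCK, BLOCK)
amax = tiles.abs().amax(dim=(1, 3))                    # (K/BLOCK, N/BLOCK)
fp8_max = torch.finfo(torch.float8_e4m3fn).max
block_scale = amax / fp8_max                           # one fp32 scale per block

# Quantize: divide each tile by its block scale, then cast to fp8.
w_q = (tiles / block_scale[:, None, :, None]).to(torch.float8_e4m3fn).reshape(K, N)

# Dequantizing with the same 2D grid of scales recovers w approximately.
w_dq = w_q.to(torch.float32).reshape(K // BLOCK, BLOCK, N // BLOCK, BLOCK)
w_dq = (w_dq * block_scale[:, None, :, None]).reshape(K, N)
print((w - w_dq).abs().max())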
@@ -245,7 +245,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   FetchContent_Declare(
         cutlass
         GIT_REPOSITORY https://github.com/nvidia/cutlass.git
-        GIT_TAG v3.6.0
+        GIT_TAG v3.7.0
         GIT_PROGRESS TRUE

         # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.

@@ -299,7 +299,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
   cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
+    set(SRCS
+      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
+      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
+      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
+      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
+      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
@@ -3,7 +3,7 @@ import copy
 import itertools
 import pickle as pkl
 import time
-from typing import Callable, Iterable, List, Tuple
+from typing import Callable, Iterable, List, Optional, Tuple

 import torch
 import torch.utils.benchmark as TBenchmark

@@ -12,6 +12,8 @@ from utils import make_rand_tensors
 from weight_shapes import WEIGHT_SHAPES

 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    w8a8_block_fp8_matmul)
 from vllm.utils import FlexibleArgumentParser

 DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
@@ -38,8 +40,15 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
     ).blocked_autorange(min_run_time=min_run_time)


-def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
-               sub_label: str) -> Iterable[TMeasurement]:
+def bench_int8(
+        dtype: torch.dtype,
+        m: int,
+        k: int,
+        n: int,
+        label: str,
+        sub_label: str,
+        bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
+    """Benchmark INT8-based kernels."""
     assert dtype == torch.int8
     a, b = make_rand_tensors(torch.int8, m, n, k)
     scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
@@ -48,155 +57,132 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     azp = torch.zeros((m, ), device="cuda", dtype=torch.int32)
     azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32)

+    bench_fns = {
+        "pytorch_bf16_bf16_bf16_matmul-no-scales":
+        lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)),
+        "pytorch_fp16_fp16_fp16_matmul-no-scales":
+        lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
+        "cutlass_i8_i8_bf16_scaled_mm":
+        lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
+        "cutlass_i8_i8_bf16_scaled_mm_bias":
+        lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, bias),
+        "cutlass_i8_i8_bf16_scaled_mm_azp":
+        lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.bfloat16, azp_adj),
+        "cutlass_i8_i8_bf16_scaled_mm_azp_bias":
+        lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.bfloat16, azp_adj, None, bias),
+        "cutlass_i8_i8_bf16_scaled_mm_azp_pt":
+        lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp),
+        "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias":
+        lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp, bias),
+    }
+
     timers = []
-    # pytorch impl - bfloat16
-    timers.append(
-        bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
-                 torch.mm, a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)))
-
-    # pytorch impl - float16
-    timers.append(
-        bench_fn(label, sub_label, "pytorch_fp16_fp16_fp16_matmul-no-scales",
-                 torch.mm, a.to(dtype=torch.float16), b.to(dtype=torch.float16)))
-
-    # cutlass impl
-    timers.append(
-        bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
-                 ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16))
-
-    # cutlass with bias
-    timers.append(
-        bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
-                 ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, bias))
-
-    # cutlass with azp per-tensor
-    timers.append(
-        bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp",
-                 ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, torch.bfloat16, azp_adj))
-
-    # cutlass with azp per-tensor + bias
-    timers.append(
-        bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_bias",
-                 ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, torch.bfloat16, azp_adj, None, bias))
-
-    # cutlass with azp per-token
-    timers.append(
-        bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt",
-                 ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp))
-
-    # cutlass with azp per-token + bias
-    timers.append(
-        bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias",
-                 ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp, bias))
-
+    for name, fn in bench_fns.items():
+        # If bench_kernels is None, run all. Otherwise, run only exact matches.
+        if bench_kernels is None or name in bench_kernels:
+            print(f"Running {name}")
+            timers.append(bench_fn(label, sub_label, name, fn))

     return timers

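With the new bench_kernels argument a caller can benchmark a single INT8 kernel by its exact name. The call below is illustrative only (it assumes the surrounding benchmark script and a CUDA device; shape and labels are arbitrary examples):

# Illustrative only: run one INT8 kernel via the new bench_kernels filter.
timers = bench_int8(torch.int8, 16, 4096, 4096,
                    label="scaled-int8-gemm",
                    sub_label="MKN=(16x4096x4096)",
                    bench_kernels=["cutlass_i8_i8_bf16_scaled_mm"])
print_timers(timers)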
-def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
-              sub_label: str) -> Iterable[TMeasurement]:
+def bench_fp8(
+        dtype: torch.dtype,
+        m: int,
+        k: int,
+        n: int,
+        label: str,
+        sub_label: str,
+        bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
+    """Benchmark FP8-based kernels."""
     assert dtype == torch.float8_e4m3fn
     a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
+    a_cont = a.contiguous()
     scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+    block_scale_a = torch.rand((m, k // 128), device="cuda", dtype=torch.float32)
+    block_scale_b = torch.rand((k // 128, n // 128), device="cuda", dtype=torch.float32)
+    block_scale_a_M_major = block_scale_a.t().contiguous().t()
+    block_scale_b_K_major = block_scale_b.t().contiguous().t()
     bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)

-    timers = []
-    # pytorch impl w. bf16
-    timers.append(
-        bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
-                 torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
-                 b.to(dtype=torch.bfloat16, device="cuda")))
-
-    # pytorch impl: bf16 output, without fp8 fast accum
-    timers.append(
-        bench_fn(label, sub_label, "pytorch_fp8_fp8_bf16_scaled_mm",
-                 torch._scaled_mm, a, b, scale_a=scale_a, scale_b=scale_b,
-                 out_dtype=torch.bfloat16))
-
-    # pytorch impl: bf16 output, with fp8 fast accum
-    timers.append(
-        bench_fn(label, sub_label, "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
-                 torch._scaled_mm, a, b, scale_a=scale_a, scale_b=scale_b,
-                 out_dtype=torch.bfloat16, use_fast_accum=True))
-
-    # pytorch impl: fp16 output, without fp8 fast accum
-    timers.append(
-        bench_fn(label, sub_label, "pytorch_fp8_fp8_fp16_scaled_mm",
-                 torch._scaled_mm, a, b, scale_a=scale_a, scale_b=scale_b,
-                 out_dtype=torch.float16))
-
-    # pytorch impl: fp16 output, with fp8 fast accum
-    timers.append(
-        bench_fn(label, sub_label, "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
-                 torch._scaled_mm, a, b, scale_a=scale_a, scale_b=scale_b,
-                 out_dtype=torch.float16, use_fast_accum=True))
-
-    # cutlass impl: bf16 output
-    timers.append(
-        bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
-                 ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16))
-
-    # cutlass impl: fp16 output
-    timers.append(
-        bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm",
-                 ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16))
-
-    # cutlass impl: bf16 output, with bias
-    timers.append(
-        bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm_bias",
-                 ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
-                 bias))
-
-    # cutlass impl: fp16 output, with bias
-    timers.append(
-        bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm_bias",
-                 ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16,
-                 bias.to(dtype=torch.float16)))
+    print(m, k, n)
+
+    bench_fns = {
+        "pytorch_bf16_bf16_bf16_matmul-no-scales":
+        lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)),
+        "pytorch_fp16_fp16_fp16_matmul-no-scales":
+        lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
+        "pytorch_fp8_fp8_fp16_scaled_mm":
+        lambda: torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float16),
+        "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum":
+        lambda: torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float16, use_fast_accum=True),
+        "pytorch_fp8_fp8_bf16_scaled_mm":
+        lambda: torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16),
+        "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum":
+        lambda: torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True),
+        "cutlass_fp8_fp8_bf16_scaled_mm":
+        lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
+        "cutlass_fp8_fp8_fp16_scaled_mm":
+        lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16),
+        "cutlass_fp8_fp8_bf16_scaled_mm_bias":
+        lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, bias),
+        "cutlass_fp8_fp8_fp16_scaled_mm_bias":
+        lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)),
+        "triton_fp8_fp8_fp16_scaled_mm_blockwise":
+        lambda: w8a8_block_fp8_matmul(a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)),
+        "cutlass_fp8_fp8_fp16_scaled_mm_blockwise":
+        lambda: ops.cutlass_scaled_mm(a, b, block_scale_a_M_major, block_scale_b_K_major, torch.float16),
+    }
+
+    timers = []
+    for name, fn in bench_fns.items():
+        # If bench_kernels is None, run all. Otherwise, run only exact matches.
+        if bench_kernels is None or name in bench_kernels:
+            print(f"Running {name}")
+            timers.append(bench_fn(label, sub_label, name, fn))

     return timers
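Note how the blockwise entries call cutlass_scaled_mm with per-group scale tensors instead of scalars: scale_a holds one value per (row, 128-wide K group) and scale_b one value per 128x128 weight block, and the benchmark hands them over laid out "M-major" and "K-major" respectively, which `t().contiguous().t()` produces (a column-major copy that keeps the original logical shape). A small illustrative sketch of those shapes and strides (not vLLM code; shapes are arbitrary examples):

# Illustrative sketch of the blockwise scale layouts used above.
import torch

m, k, n = 16, 4096, 6144
block_scale_a = torch.rand(m, k // 128)          # one scale per (row, 128-wide K group)
block_scale_b = torch.rand(k // 128, n // 128)   # one scale per 128x128 weight block

# t().contiguous().t() keeps the logical shape but makes the first dimension the
# fastest-varying one in memory ("M-major" for scale_a, "K-major" for scale_b).
a_M_major = block_scale_a.t().contiguous().t()
b_K_major = block_scale_b.t().contiguous().t()
print(block_scale_a.shape, block_scale_a.stride())   # torch.Size([16, 32]) (32, 1)
print(a_M_major.shape, a_M_major.stride())           # torch.Size([16, 32]) (1, 16)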


-def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
-          sub_label: str) -> Iterable[TMeasurement]:
+def bench(dtype: torch.dtype,
+          m: int,
+          k: int,
+          n: int,
+          label: str,
+          sub_label: str,
+          bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
     if dtype == torch.int8:
-        return bench_int8(dtype, m, k, n, label, sub_label)
+        return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
     if dtype == torch.float8_e4m3fn:
-        return bench_fp8(dtype, m, k, n, label, sub_label)
+        return bench_fp8(dtype, m, k, n, label, sub_label, bench_kernels)
     raise ValueError("unsupported type")

@@ -207,18 +193,22 @@ def print_timers(timers: Iterable[TMeasurement]):


 def run(dtype: torch.dtype,
-        MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
+        MKNs: Iterable[Tuple[int, int, int]],
+        bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
     results = []
     for m, k, n in MKNs:
-        timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
-                       f"MKN=({m}x{k}x{n})")
+        timers = bench(dtype,
+                       m,
+                       k,
+                       n,
+                       f"scaled-{dtype}-gemm",
+                       f"MKN=({m}x{k}x{n})",
+                       bench_kernels=bench_kernels)
         print_timers(timers)
         results.extend(timers)

     return results


-# output makers
 def make_output(data: Iterable[TMeasurement],
                 MKNs: Iterable[Tuple[int, int, int]],
                 base_description: str,
@@ -232,15 +222,11 @@ def make_output(data: Iterable[TMeasurement],
         pkl.dump(data, f)


-# argparse runners
-
-
 def run_square_bench(args):
     dim_sizes = list(
         range(args.dim_start, args.dim_end + 1, args.dim_increment))
     MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
-    data = run(args.dtype, MKNs)
+    data = run(args.dtype, MKNs, bench_kernels=args.kernels)

     make_output(data, MKNs, f"square_bench-{args.dtype}")

@@ -251,8 +237,7 @@ def run_range_bench(args):
     Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
     Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
     MKNs = list(zip(Ms, Ks, Ns))
-    data = run(args.dtype, MKNs)
+    data = run(args.dtype, MKNs, bench_kernels=args.kernels)

     make_output(data, MKNs, f"range_bench-{args.dtype}")

@@ -278,7 +263,7 @@ def run_model_bench(args):
         for k, n in KNs:
             MKNs.append((m, k, n))

-        data = run(args.dtype, MKNs)
+        data = run(args.dtype, MKNs, bench_kernels=args.kernels)
         model_bench_data.append(data)

     # Print all results
@@ -328,6 +313,15 @@ Benchmark Cutlass GEMM.
         type=to_torch_dtype,
         required=True,
         help="Available options are ['int8', 'fp8']")
+    parser.add_argument(
+        "--kernels",
+        nargs="+",
+        type=str,
+        default=None,
+        help=
+        "Exact names of the kernels to benchmark. If not set, runs all kernels."
+    )

     subparsers = parser.add_subparsers(dest="cmd")

     square_parser = subparsers.add_parser("square_bench")
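The new --kernels flag simply forwards a list of exact kernel names into run()/bench(); any name not present in the bench_fns dicts is skipped. A programmatic equivalent, restricted to the two new blockwise entries, looks like this (illustrative only; it assumes the benchmark script's context and made-up shapes):

# Illustrative only: programmatic equivalent of passing --kernels on the CLI.
results = run(torch.float8_e4m3fn, [(16, 4096, 4096), (128, 4096, 4096)],
              bench_kernels=[
                  "triton_fp8_fp8_fp16_scaled_mm_blockwise",
                  "cutlass_fp8_fp8_fp16_scaled_mm_blockwise",
              ])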
@@ -1,7 +1,14 @@
+#pragma once
+
 #include <climits>
 #include <iostream>

-inline uint32_t next_pow_2(uint32_t const num) {
+inline constexpr uint32_t next_pow_2(uint32_t const num) {
   if (num <= 1) return num;
   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }
+
+template <typename T>
+inline constexpr std::enable_if_t<std::is_integral_v<T>, T> ceil_div(T a, T b) {
+  return (a + b - 1) / b;
+}

@@ -32,3 +32,20 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
 }

 int32_t get_sm_version_num();
+
+/**
+ * A wrapper for a kernel that is used to guard against compilation on
+ * architectures that will never use the kernel. The purpose of this is to
+ * reduce the size of the compiled binary.
+ * __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
+ * into code that will be executed on the device where it is defined.
+ */
+template <typename Kernel>
+struct enable_sm90_or_later : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE void operator()(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
+    Kernel::operator()(std::forward<Args>(args)...);
+#endif
+  }
+};
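For reference, a minimal Python model of the two new helpers (behavioral sketch only; the real implementations are the C++ above):

# Behavioral model of next_pow_2 and ceil_div, for reference only.
def next_pow_2(num: int) -> int:
    # Smallest power of two >= num; mirrors 1 << (32 - clz(num - 1)) for num > 1.
    return num if num <= 1 else 1 << (num - 1).bit_length()

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

assert next_pow_2(5) == 8 and next_pow_2(8) == 8 and ceil_div(4096, 128) == 32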
csrc/cutlass_extensions/gemm/collective/collective_builder.hpp (new file, 123 lines)
@@ -0,0 +1,123 @@
// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl
// clang-format off
#pragma once

#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl"

#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"


/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA_TMA_WS_SS (BlockScaled Builders)
template <
  class ElementA,
  class GmemLayoutATag,
  int AlignmentA,
  class ElementB,
  class GmemLayoutBTag,
  int AlignmentB,
  class ElementAccumulator,
  class TileShape_MNK,
  class ClusterShape_MNK,
  class StageCountType,
  int ScaleGranularityM
>
struct CollectiveBuilder<
    arch::Sm90,
    arch::OpClassTensorOp,
    ElementA,
    GmemLayoutATag,
    AlignmentA,
    ElementB,
    GmemLayoutBTag,
    AlignmentB,
    ElementAccumulator,
    TileShape_MNK,
    ClusterShape_MNK,
    StageCountType,
    KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>,
    cute::enable_if_t<
      not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
> {
  using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;

  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
  static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
                "Should meet TMA alignment requirement\n");

  static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
                                                                   KernelPtrArrayTmaWarpSpecializedCooperative,
                                                                   KernelPtrArrayTmaWarpSpecializedPingpong>);
  static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
  static_assert((!IsFP8Input || !IsArrayOfPointersGemm),
                "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");

  // For fp32 types, map to tf32 MMA value type
  using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
  using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;

  static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
  static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();

  static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
                                                          KernelTmaWarpSpecializedCooperative,
                                                          KernelPtrArrayTmaWarpSpecializedCooperative,
                                                          KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
  using AtomLayoutMNK = cute::conditional_t<IsCooperative,
      Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
      ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));

  using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
      GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
      GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());

  static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
  static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);

  static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
      ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
  using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;

  using SmemCopyAtomA = void;
  using SmemCopyAtomB = void;

  using CollectiveOp = CollectiveMma<
      DispatchPolicy,
      TileShape_MNK,
      ElementA,
      TagToStrideA_t<GmemLayoutATag>,
      ElementB,
      TagToStrideB_t<GmemLayoutBTag>,
      TiledMma,
      GmemTiledCopyA,
      SmemLayoutAtomA,
      SmemCopyAtomA,
      cute::identity,
      GmemTiledCopyB,
      SmemLayoutAtomB,
      SmemCopyAtomB,
      cute::identity
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////
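The only new degree of freedom this builder adds over the stock SM90 GMMA builder is the ScaleGranularityM template parameter: as the mainloop below spells out, a value of 0 means one scale per CTA tile along M, otherwise a tile of TILE_M rows carries TILE_M / ScaleGranularityM scales per K block (ScaleMsPerTile), and the granularity must divide the tile evenly. A quick illustrative check of that relation in Python (tile sizes here are just examples):

# Illustrative only: the ScaleGranularityM -> ScaleMsPerTile relation used below.
def scale_ms_per_tile(tile_m: int, scale_granularity_m: int) -> int:
    g = tile_m if scale_granularity_m == 0 else scale_granularity_m
    assert tile_m % g == 0, "scaling granularity must evenly divide the tile along M"
    return tile_m // g

print(scale_ms_per_tile(128, 0))   # 1   -> a single scale per CTA tile along M
print(scale_ms_per_tile(128, 1))   # 128 -> one scale per row (per token) of the tile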
csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp (new file, 183 lines)
@@ -0,0 +1,183 @@
// clang-format off
// adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp

/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cute/algorithm/clear.hpp"
#include "cute/tensor.hpp"

//////////////////////////////////////////////////////////////////////////////
///////////////////////////////////FP8 Accumulation///////////////////////////
//////////////////////////////////////////////////////////////////////////////
/// This class provides API to promote (add) or scale (multiply_add) the results
/// from the tensor core accumulators to the main accumulators when the number
/// of MMAs reaches the max number of MMA interval specified by user, after that
/// the tensor core accumulators are zeroed.
//////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

template <
    class EngineAccum,
    class LayoutAccum>
struct GmmaFP8AccumulationWithScale {
  using TensorAccum = cute::Tensor<EngineAccum, LayoutAccum>;
  using ElementAccumulator = typename EngineAccum::value_type;

  static_assert(is_static<LayoutAccum>::value, "Accumulator Layout should be static");
  static_assert(is_rmem<TensorAccum>::value , "Accumulator tensor must be rmem resident.");

private:
  TensorAccum& accum_;
  TensorAccum accum_temp_;

  uint32_t accum_promotion_interval_;         // defines the max num of executed MMAs after which accum should be promoted.
  uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop
  uint32_t mma_count_;                        // current executed MMAs
  uint32_t reset_accum_flag_;                 // accum needs to be zeroed or not.

  // promote or `add` the partial accumulators to main accumulator (FADD).
  CUTLASS_DEVICE
  void promote_core() {
    warpgroup_wait<0>();
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(accum_); ++i) {
      accum_(i) += accum_temp_(i);
    }
  }

  // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA).
  template <
    class EngineScale,
    class LayoutScale>
  CUTLASS_DEVICE
  void scale_core(const cute::Tensor<EngineScale, LayoutScale> &scale) {
    using TensorScale = cute::Tensor<EngineScale, LayoutScale>;

    static_assert(is_static<LayoutScale>::value, "Scale Layout should be static");
    static_assert(is_rmem<TensorScale>::value , "Scale tensor must be rmem resident.");

    static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape.");

    warpgroup_wait<0>();
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(accum_); ++i) {
      accum_(i) += accum_temp_(i) * scale(i);
    }
  }

public:
  CUTLASS_DEVICE
  GmmaFP8AccumulationWithScale(
      TensorAccum &accum,
      uint32_t accum_promotion_interval,
      uint32_t mma_count_per_mainloop_iteration)
      : accum_(accum),
        accum_promotion_interval_(accum_promotion_interval),
        mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration),
        mma_count_(0),
        reset_accum_flag_(0)
  {
    accum_temp_ = cute::make_fragment_like(accum);
  }

  //
  // Methods (Common)
  //

  CUTLASS_DEVICE
  TensorAccum& operator()() {
    return accum_temp_;
  }

  /// prepare the MMA accumulators when initialization or zeroing is required.
  CUTLASS_DEVICE
  bool prepare_if_needed() {
    return reset_accum_flag_;
  }

  //
  // Methods (for FADD version)
  //

  /// promote (add) the results from the MMA accumulators to main accumulator if needed.
  CUTLASS_DEVICE
  void promote_if_needed() {
    mma_count_ += mma_count_per_mainloop_iteration_;
    reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
    if (reset_accum_flag_) {
      promote_core();
      mma_count_ = 0;
    }
  }

  /// promote (add) the residue results from the MMA accumulators to main accumulator if needed.
  CUTLASS_DEVICE
  void promote_residue_if_needed() {
    if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
      promote_core();
    }
  }

  //
  // Methods (for FFMA version)
  //

  /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed.
  template <
    class EngineScale,
    class LayoutScale>
  CUTLASS_DEVICE
  void scale_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
    mma_count_ += mma_count_per_mainloop_iteration_;
    reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
    if (reset_accum_flag_) {
      scale_core(scale);
      mma_count_ = 0;
    }
  }

  /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed.
  template <
    class EngineScale,
    class LayoutScale>
  CUTLASS_DEVICE
  void scale_residue_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
    if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
      scale_core(scale);
    }
  }
};

} // namespace cutlass::gemm::collective
csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp (new file, 730 lines)
@@ -0,0 +1,730 @@
// clang-format off
// Adapted (Heavily) from: https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp

/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (same BSD-3-Clause license text as in fp8_accumulation.hpp above)
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/trace.h"
#include "cutlass/numeric_types.h"

#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm80.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"

#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/fp8_accumulation.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {
using namespace cute;

/////////////////////////////////////////////////////////////////////////////////////////////////

// WarpSpecialized Mainloop
template <
  int Stages,
  class ClusterShape,
  class KernelSchedule,
  int ScaleGranularityM_,
  class TileShape_,
  class ElementA_,
  class StrideA_,
  class ElementB_,
  class StrideB_,
  class TiledMma_,
  class GmemTiledCopyA_,
  class SmemLayoutAtomA_,
  class SmemCopyAtomA_,
  class TransformA_,
  class GmemTiledCopyB_,
  class SmemLayoutAtomB_,
  class SmemCopyAtomB_,
  class TransformB_>
struct CollectiveMma<
    MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>,
    TileShape_,
    ElementA_,
    StrideA_,
    ElementB_,
    StrideB_,
    TiledMma_,
    GmemTiledCopyA_,
    SmemLayoutAtomA_,
    SmemCopyAtomA_,
    TransformA_,
    GmemTiledCopyB_,
    SmemLayoutAtomB_,
    SmemCopyAtomB_,
    TransformB_>
{
  //
  // Type Aliases
  //
  using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>;
  using TileShape = TileShape_;
  using ElementA = ElementA_;
  using StrideA = StrideA_;
  using ElementB = ElementB_;
  using StrideB = StrideB_;
  using TiledMma = TiledMma_;
  using ElementAccumulator = typename TiledMma::ValTypeC;
  using ElementBlockScale = ElementAccumulator;
  using GmemTiledCopyA = GmemTiledCopyA_;
  using GmemTiledCopyB = GmemTiledCopyB_;
  using SmemLayoutAtomA = SmemLayoutAtomA_;
  using SmemLayoutAtomB = SmemLayoutAtomB_;
  using SmemCopyAtomA = SmemCopyAtomA_;
  using SmemCopyAtomB = SmemCopyAtomB_;
  using TransformA = TransformA_;
  using TransformB = TransformB_;
  using ArchTag = typename DispatchPolicy::ArchTag;

  using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
  using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
  using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
  using PipelineParams = typename MainloopPipeline::Params;

  // Two threads per CTA are producers (1 for operand tile and 32 for scales)
  static constexpr int NumProducerThreadEvents = 33;

  static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
  static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;

  static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
  static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
  static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

  static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
  static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
  static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

  static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M.");

  // Tile along modes in a way that maximizes the TMA box size.
  using SmemLayoutA = decltype(tile_to_shape(
      SmemLayoutAtomA{},
      make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
      cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
  using SmemLayoutB = decltype(tile_to_shape(
      SmemLayoutAtomB{},
      make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
      cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));

  // Block scaling gmem-to-smem copy atom
  using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
  using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;

  // Block scaling smem layout
  using SmemLayoutScaleA = Layout<Shape<Int<ScaleMsPerTile>, Int<DispatchPolicy::Stages>>>;
  using SmemLayoutScaleB = Layout<Shape<Int<DispatchPolicy::Stages>>, Stride<_1>>; // `ScaleNsPerTile` is always 1.

  static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more.");
  static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
                cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
                "MMA atom must source both A and B operand from smem_desc for this mainloop.");
  static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
                "GmemTiledCopy - invalid SM90 TMA copy atom specified.");
  static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
                "GmemTiledCopy - invalid SM90 TMA copy atom specified.");
  static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
                "ElementAccumulator and ElementBlockScale should be same datatype");

  struct SharedStorage
  {
    struct TensorStorage : cute::aligned_struct<128> {
      cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;   // mxk
      cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;   // nxk
      cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleA>> smem_scale_A;  // ScaleMsPerTile x k
      cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleB>> smem_scale_B;  // 1xk
    } tensors;

    using PipelineStorage = typename MainloopPipeline::SharedStorage;
    PipelineStorage pipeline;
  };
  using TensorStorage = typename SharedStorage::TensorStorage;
  using PipelineStorage = typename SharedStorage::PipelineStorage;

  // Host side kernel arguments
  struct Arguments {
    ElementA const* ptr_A;
    StrideA dA;
    ElementB const* ptr_B;
    StrideB dB;
    ElementBlockScale const* ptr_scale_A;
    ElementBlockScale const* ptr_scale_B;
  };

  // Device side kernel params
  struct Params {
    // Assumption: StrideA is congruent with Problem_MK
    using TMA_A = decltype(make_tma_copy_A_sm90(
        GmemTiledCopyA{},
        make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
        SmemLayoutA{}(_,_,0),
        TileShape{},
        ClusterShape{}));
    // Assumption: StrideB is congruent with Problem_NK
    using TMA_B = decltype(make_tma_copy_B_sm90(
        GmemTiledCopyB{},
        make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
        SmemLayoutB{}(_,_,0),
        TileShape{},
        ClusterShape{}));
    TMA_A tma_load_a;
    TMA_B tma_load_b;
    uint32_t tma_transaction_bytes = TmaTransactionBytes;
    uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
    uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
    // Block scaling factors for A and B
    ElementBlockScale const* ptr_scale_A;
    ElementBlockScale const* ptr_scale_B;
  };

  //
  // Methods
  //

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    (void) workspace;

    // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
    auto problem_shape_MNKL = append<4>(problem_shape, 1);
    auto [M,N,K,L] = problem_shape_MNKL;

    auto ptr_A = reinterpret_cast<ElementA const*>(args.ptr_A);
    auto ptr_B = reinterpret_cast<ElementB const*>(args.ptr_B);

    Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
    Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB));
    typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90(
        GmemTiledCopyA{},
        tensor_a,
        SmemLayoutA{}(_,_,cute::Int<0>{}),
        TileShape{},
        ClusterShape{});
    typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90(
        GmemTiledCopyB{},
        tensor_b,
        SmemLayoutB{}(_,_,cute::Int<0>{}),
        TileShape{},
        ClusterShape{});
    uint32_t transaction_bytes_mk = TmaTransactionBytesMK;
    uint32_t transaction_bytes_nk = TmaTransactionBytesNK;
    uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk;

    return {
      tma_load_a,
      tma_load_b,
      transaction_bytes,
      transaction_bytes_mk,
      transaction_bytes_nk,
      args.ptr_scale_A,
      args.ptr_scale_B
    };
  }

  template<class ProblemShape>
  static bool
  can_implement(
      ProblemShape const& problem_shape,
      [[maybe_unused]] Arguments const& args) {
    constexpr int tma_alignment_bits = 128;
    auto problem_shape_MNKL = append<4>(problem_shape, 1);
    auto [M,N,K,L] = problem_shape_MNKL;

    bool implementable = true;
    constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
    constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});

    if (!implementable) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
    }
    return implementable;
  }

  static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
  static constexpr int K_PIPE_MMAS = 1;
  static constexpr uint32_t TmaTransactionBytesMK =
      cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value));
  static constexpr uint32_t TmaTransactionBytesNK =
      cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value));
  static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK;

  /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  CUTLASS_DEVICE
  static void prefetch_tma_descriptors(Params const& mainloop_params)
  {
    cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
    cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
  }

  /// Set up the data needed by this collective for load and mma.
  /// Returns a tuple of tensors. The collective and the kernel layer have the contract
  /// Returned tuple must contain at least two elements, with the first two elements being:
  /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
  /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
  template <class ProblemShape_MNKL>
  CUTLASS_DEVICE auto
  load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
    using X = Underscore;
    // Separate out problem shape for convenience
    auto [M,N,K,L] = problem_shape_MNKL;

    // TMA requires special handling of strides to deal with coord codomain mapping
    // Represent the full tensors -- get these from TMA
    Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L));          // (m,k,l)
    Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L));          // (n,k,l)

    // Make tiled views, defer the slice
    Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{});  // (BLK_M,BLK_K,m,k,l)
    Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{});  // (BLK_N,BLK_K,n,k,l)

    constexpr auto scales_m = Int<ScaleMsPerTile>{};
    auto tM = get<2>(gA_mkl.shape());
    auto tN = get<2>(gB_nkl.shape());
    auto tK = get<3>(gA_mkl.shape());

    // Make the tiled views of scale tensors
    auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L);                 // (scale_m,k,l)
    auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{});
    auto scaleB_shape = make_shape(tN, tK, L);                                    // (n,k,l)
    auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{});

    // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
    // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl.
    Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l)
    Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l)

    return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl);
  }

  /// Perform a collective-scoped matrix multiply-accumulate
  /// Producer Perspective
  template <
    class TensorA, class TensorB,
    class TensorScaleA, class TensorScaleB,
    class KTileIterator, class BlockCoord
  >
  CUTLASS_DEVICE void
  load(
      Params const& mainloop_params,
      MainloopPipeline pipeline,
      PipelineState smem_pipe_write,
      cute::tuple<TensorA, TensorB, TensorScaleA, TensorScaleB> const& load_inputs,
      BlockCoord const& blk_coord,
      KTileIterator k_tile_iter, int k_tile_count,
      int thread_idx,
      uint32_t block_rank_in_cluster,
      TensorStorage& shared_tensors) {
    int lane_predicate = cute::elect_one_sync();

    // Blockscaling: Tma loads for load_input and CpAsync for load_scale
    Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{});                       // (BLK_M,BLK_K,PIPE)
    Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{});                       // (BLK_N,BLK_K,PIPE)
    Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k)
    Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)

    //
    // Prepare the TMA loads for A and B
    //

    constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
    uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};

    Tensor gA_mkl = get<0>(load_inputs);
    Tensor gB_nkl = get<1>(load_inputs);

    auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
    auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);

    // Partition the inputs based on the current block coordinates.
    auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
    Tensor gA = gA_mkl(_,_,m_coord,_,l_coord);   // (BLK_M,BLK_K,k)
    Tensor gB = gB_nkl(_,_,n_coord,_,l_coord);   // (BLK_N,BLK_K,k)

    // Block scaling: load_scale has scaling tensors in global memory which are not tiled
    Tensor mScaleA_mkl = get<2>(load_inputs);
    Tensor mScaleB_nkl = get<3>(load_inputs);
    auto scales_m = get<0>(mScaleA_mkl.shape());

    Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape());

    Tensor gScaleA = local_tile(
      mScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
      make_coord(m_coord,_,l_coord));            // (ScaleMsPerTile,k,1)
    Tensor cScaleA = local_tile(
      cScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
      make_coord(m_coord,_,l_coord));
    Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord);   // (1,k,1)

    // TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
    TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
      Layout<Shape<_32, _1>>{}, Layout<Shape<_4, _1>>{});  // (1,1,1)
    TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
      Layout<Shape<_1>>{}, Layout<Shape<_1>>{});           // (1,1,1)
    ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
    ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);

    Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
    Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA);
    Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);

    Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
    Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);

    // Applies the mapping from block_tma_a
    Tensor tAgA = block_tma_a.partition_S(gA);   // (TMA,TMA_M,TMA_K,k)
    Tensor tAsA = block_tma_a.partition_D(sA);   // (TMA,TMA_M,TMA_K,PIPE)

    Tensor tBgB = block_tma_b.partition_S(gB);   // (TMA,TMA_N,TMA_K,k)
    Tensor tBsB = block_tma_b.partition_D(sB);   // (TMA,TMA_N,TMA_K,PIPE)

    uint16_t mcast_mask_a = 0;
    uint16_t mcast_mask_b = 0;

    // Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors
    // Maps the tile -> block, value
    if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
      auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
      for (int n = 0; n < size<1>(block_layout); ++n) {
        mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
      }
    }

    if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
      auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
      for (int m = 0; m < size<0>(block_layout); ++m) {
        mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
      }
    }

    // Allocate predicate tensors for a_scales (since we can't guarantee that
    // all scales are valid, since we could have partial tiles along M)
    Tensor tApA_ScaleA = make_tensor<bool>(shape(tAsA_ScaleA(_,_,0)));
    #pragma unroll
    for (int i = 0; i < size(tApA_ScaleA); ++i) {
      tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m;
    }

    // Mainloop
    CUTLASS_PRAGMA_NO_UNROLL
    for ( ; k_tile_count > 0; --k_tile_count) {
      // LOCK smem_pipe_write for _writing_
      pipeline.producer_acquire(smem_pipe_write);

      //
      // Copy gmem to smem for *k_tile_iter
      //
      int write_stage = smem_pipe_write.index();
      using BarrierType = typename MainloopPipeline::ProducerBarrierType;
      BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);

      // Copy operands A and B from global memory to shared memory
      if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
      if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
|
||||||
|
|
||||||
|
// Copy scale tensors from global memory to shared memory
|
||||||
|
copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
|
||||||
|
copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage));
|
||||||
|
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
|
||||||
|
|
||||||
|
++k_tile_iter;
|
||||||
|
|
||||||
|
// Advance smem_pipe_write
|
||||||
|
++smem_pipe_write;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||||
|
CUTLASS_DEVICE void
|
||||||
|
load_tail(
|
||||||
|
MainloopPipeline pipeline,
|
||||||
|
PipelineState smem_pipe_write) {
|
||||||
|
int lane_predicate = cute::elect_one_sync();
|
||||||
|
|
||||||
|
// Issue the epilogue waits
|
||||||
|
if (lane_predicate) {
|
||||||
|
/* This helps avoid early exit of blocks in Cluster
|
||||||
|
* Waits for all stages to either be released (all
|
||||||
|
* Consumer UNLOCKs), or if the stage was never used
|
||||||
|
* then would just be acquired since the phase was
|
||||||
|
* still inverted from make_producer_start_state
|
||||||
|
*/
|
||||||
|
pipeline.producer_tail(smem_pipe_write);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Perform a collective-scoped matrix multiply-accumulate
|
||||||
|
/// Consumer Perspective
|
||||||
|
template <
|
||||||
|
class FrgTensorC
|
||||||
|
>
|
||||||
|
CUTLASS_DEVICE void
|
||||||
|
mma(MainloopPipeline pipeline,
|
||||||
|
PipelineState smem_pipe_read,
|
||||||
|
FrgTensorC& accum,
|
||||||
|
int k_tile_count,
|
||||||
|
int thread_idx,
|
||||||
|
TensorStorage& shared_tensors,
|
||||||
|
Params const& mainloop_params) {
|
||||||
|
|
||||||
|
|
||||||
|
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
||||||
|
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
|
||||||
|
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
|
||||||
|
static_assert(cute::is_void_v<SmemCopyAtomA>,
|
||||||
|
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||||
|
static_assert(cute::is_void_v<SmemCopyAtomB>,
|
||||||
|
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||||
|
|
||||||
|
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||||
|
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||||
|
|
||||||
|
// Block scaling
|
||||||
|
Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()),
|
||||||
|
Layout<
|
||||||
|
Shape<Shape<Int<ScaleGranularityM>, Int<ScaleMsPerTile>>, cute::tuple_element_t<1, TileShape>, Int<DispatchPolicy::Stages>>,
|
||||||
|
Stride<Stride<_0, _1>, _0, Int<ScaleMsPerTile>>
|
||||||
|
>{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k)
|
||||||
|
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
|
||||||
|
|
||||||
|
//
|
||||||
|
// Define C accumulators and A/B partitioning
|
||||||
|
//
|
||||||
|
|
||||||
|
// Layout of warp group to thread mapping
|
||||||
|
|
||||||
|
static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and
|
||||||
|
stride<0>(typename TiledMma::BLayout{}) == 0 and
|
||||||
|
size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and
|
||||||
|
size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup,
|
||||||
|
"Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
|
||||||
|
|
||||||
|
constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup;
|
||||||
|
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
|
||||||
|
Int<NumThreadsPerWarpGroup>{});
|
||||||
|
|
||||||
|
int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
|
||||||
|
|
||||||
|
TiledMma tiled_mma;
|
||||||
|
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
|
||||||
|
|
||||||
|
Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C.
|
||||||
|
|
||||||
|
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||||
|
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||||
|
|
||||||
|
// Allocate "fragments/descriptors"
|
||||||
|
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||||
|
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||||
|
|
||||||
|
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M
|
||||||
|
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
|
||||||
|
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
|
||||||
|
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
|
||||||
|
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
|
||||||
|
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
|
||||||
|
|
||||||
|
//
|
||||||
|
// PIPELINED MAIN LOOP
|
||||||
|
//
|
||||||
|
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
|
||||||
|
"ERROR : Incorrect number of MMAs in flight");
|
||||||
|
|
||||||
|
// We release buffers to producer warps(dma load) with some mmas in flight
|
||||||
|
PipelineState smem_pipe_release = smem_pipe_read;
|
||||||
|
|
||||||
|
// Per block scale values for operand A and B
|
||||||
|
|
||||||
|
using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout.
|
||||||
|
using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above
|
||||||
|
|
||||||
|
Tensor tCrScaleAViewAsC = make_tensor<ElementBlockScale>(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N)
|
||||||
|
ElementBlockScale scale_b;
|
||||||
|
|
||||||
|
// Prologue GMMAs
|
||||||
|
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||||
|
|
||||||
|
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||||
|
|
||||||
|
GmmaFP8AccumulationWithScale accumulation(accum, size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), size<2>(tCrA));
|
||||||
|
warpgroup_fence_operand(accumulation());
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
|
||||||
|
{
|
||||||
|
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||||
|
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||||
|
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||||
|
|
||||||
|
if (accumulation.prepare_if_needed()) {
|
||||||
|
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||||
|
}
|
||||||
|
|
||||||
|
int read_stage = smem_pipe_read.index();
|
||||||
|
|
||||||
|
// Load per block scale values from shared memory to registers.
|
||||||
|
scale_b = sScaleB[read_stage];
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||||
|
tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
|
||||||
|
}
|
||||||
|
if constexpr (ScaleMsPerTile == 1) {
|
||||||
|
static_assert(size(RegLayoutScaleAEssential{}) == 1);
|
||||||
|
tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
|
||||||
|
} else {
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||||
|
tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
warpgroup_arrive();
|
||||||
|
// Unroll the K mode manually to set scale D to 1
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||||
|
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||||
|
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
|
||||||
|
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||||
|
}
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
|
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
|
||||||
|
accumulation.scale_if_needed(tCrScaleAViewAsC);
|
||||||
|
|
||||||
|
++smem_pipe_read;
|
||||||
|
}
|
||||||
|
|
||||||
|
warpgroup_fence_operand(accumulation());
|
||||||
|
// Mainloop GMMAs
|
||||||
|
k_tile_count -= prologue_mma_count;
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_NO_UNROLL
|
||||||
|
for ( ; k_tile_count > 0; --k_tile_count)
|
||||||
|
{
|
||||||
|
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||||
|
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||||
|
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Compute on k_tile
|
||||||
|
//
|
||||||
|
|
||||||
|
int read_stage = smem_pipe_read.index();
|
||||||
|
|
||||||
|
// Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N)
|
||||||
|
scale_b = sScaleB[read_stage];
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||||
|
tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
|
||||||
|
}
|
||||||
|
if constexpr (ScaleMsPerTile == 1) {
|
||||||
|
static_assert(size(RegLayoutScaleAEssential{}) == 1);
|
||||||
|
tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
|
||||||
|
} else {
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||||
|
tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (accumulation.prepare_if_needed()) {
|
||||||
|
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||||
|
}
|
||||||
|
|
||||||
|
warpgroup_fence_operand(accumulation());
|
||||||
|
warpgroup_arrive();
|
||||||
|
// Unroll the K mode manually to set scale D to 1
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||||
|
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||||
|
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
|
||||||
|
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||||
|
}
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
|
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
|
||||||
|
warpgroup_wait<K_PIPE_MMAS>();
|
||||||
|
warpgroup_fence_operand(accumulation());
|
||||||
|
|
||||||
|
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
|
||||||
|
accumulation.scale_if_needed(tCrScaleAViewAsC);
|
||||||
|
|
||||||
|
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||||
|
|
||||||
|
// Advance smem_pipe_read and smem_pipe_release
|
||||||
|
++smem_pipe_read;
|
||||||
|
++smem_pipe_release;
|
||||||
|
}
|
||||||
|
|
||||||
|
accumulation.scale_residue_if_needed(tCrScaleAViewAsC);
|
||||||
|
|
||||||
|
warpgroup_fence_operand(accumulation());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Perform a Consumer Epilogue to release all buffers
|
||||||
|
CUTLASS_DEVICE void
|
||||||
|
mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
|
||||||
|
// Prologue GMMAs
|
||||||
|
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||||
|
k_tile_count -= prologue_mma_count;
|
||||||
|
|
||||||
|
smem_pipe_release.advance(k_tile_count);
|
||||||
|
|
||||||
|
// Wait on all GMMAs to complete
|
||||||
|
warpgroup_wait<0>();
|
||||||
|
|
||||||
|
for (int count = 0; count < prologue_mma_count; ++count) {
|
||||||
|
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||||
|
++smem_pipe_release;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace cutlass::gemm::collective
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
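For orientation: the mainloop above applies, per K block, the per-row A scales held in tCrScaleAViewAsC (pre-multiplied by the single B scale of that block) to the partial GMMA results before folding them into the accumulator. A rough NumPy sketch of the same math, assuming one A scale per group_m rows and one B scale per group_k x group_n block (names and group sizes are illustrative, not the kernel's API):

# Reference for blockwise-scaled GEMM semantics (illustrative only).
import numpy as np

def blockwise_scaled_gemm(a, b, a_scales, b_scales,
                          group_m=1, group_n=128, group_k=128):
    # a: (M, K), b: (K, N) quantized values, promoted to float here
    # a_scales: (M // group_m, K // group_k)
    # b_scales: (K // group_k, N // group_n)
    M, K = a.shape
    N = b.shape[1]
    out = np.zeros((M, N), dtype=np.float32)
    for kb in range(K // group_k):
        ks = slice(kb * group_k, (kb + 1) * group_k)
        # unscaled partial product for this K block
        partial = a[:, ks].astype(np.float32) @ b[ks, :].astype(np.float32)
        # per-row A scale and per-column B scale for this K block
        row_scale = np.repeat(a_scales[:, kb], group_m)[:, None]
        col_scale = np.repeat(b_scales[kb, :], group_n)[None, :]
        out += partial * row_scale * col_scale
    return out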
csrc/cutlass_extensions/gemm/dispatch_policy.hpp (new file)
@ -0,0 +1,39 @@
#pragma once

#include "cutlass/gemm/dispatch_policy.hpp"

namespace cutlass::gemm {

//////////////////////////////////////////////////////////////////////////////

// FP8 related policies (including Blocked Scaled Accumulation)
// `ScaleGranularityM` specifies scaling granularity along M, while zero-value
// `ScaleGranularityM` indicates that scaling granularity is
// `size<0>(TileShape_MNK{})` along M.
template <int ScaleGranularityM = 0>
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
    : KernelTmaWarpSpecializedCooperative {};

// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
// specialized dynamic schedule For FP8 kernels with Block Scaling
template <int Stages_, class ClusterShape_ = Shape<_1, _1, _1>,
          class KernelSchedule = KernelTmaWarpSpecialized,
          int ScaleGranularityM =
              0  // `ScaleGranularityM` specifies scaling granularity along M,
                 // while zero-value `ScaleGranularityM` indicates that scaling
                 // granularity is `size<0>(TileShape_MNK{})` along M.
          >
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
    : MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_,
                                         KernelSchedule> {
  static_assert(
      cute::is_same_v<
          KernelSchedule,
          KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
              ScaleGranularityM>>,
      "KernelSchedule must be one of the warp specialized policies");
};

//////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm
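A small sketch of the bookkeeping implied by the comments above (the helper name is assumed, not part of the CUTLASS API): a non-zero ScaleGranularityM divides the tile's M extent into groups of that many rows, each carrying its own A scale, while a zero value collapses to a single scale for the whole tile along M.

def scale_ms_per_tile(tile_m: int, scale_granularity_m: int) -> int:
    # scale_granularity_m == 0 means "one scale for the full tile along M"
    if scale_granularity_m == 0:
        return 1
    assert tile_m % scale_granularity_m == 0
    return tile_m // scale_granularity_m

# e.g. a 128-row tile with per-row scales (granularity 1) carries 128 scales,
# while granularity 0 carries a single scale for the tile
assert scale_ms_per_tile(128, 1) == 128
assert scale_ms_per_tile(128, 0) == 1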
@ -1,6 +1,6 @@
#pragma once

#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"

namespace cutlass::gemm::collective {
using namespace cute;
csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh (new file)
@ -0,0 +1,93 @@
#pragma once

// clang-format will break include orders
// clang-format off
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
// clang-format on

namespace vllm::c3x {

static inline cute::Shape<int, int, int, int> get_problem_shape(
    torch::Tensor const& a, torch::Tensor const& b) {
  int32_t m = a.size(0), n = b.size(1), k = a.size(1);
  return {m, n, k, 1};
}

template <typename GemmKernel>
void cutlass_gemm_caller(torch::Device device,
                         cute::Shape<int, int, int, int> prob_shape,
                         typename GemmKernel::MainloopArguments mainloop_args,
                         typename GemmKernel::EpilogueArguments epilogue_args) {
  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
                                      prob_shape, mainloop_args, epilogue_args};

  // Launch the CUTLASS GEMM kernel.
  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(device);
  auto workspace = torch::empty(workspace_size, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(device.index());

  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
  CUTLASS_CHECK(status);
}

template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                         torch::Tensor const& b,
                         EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;
  using GemmKernel = typename Gemm::GemmKernel;

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideA = cute::Stride<int64_t, cute::Int<1>, int64_t>;
  using StrideB = cute::Stride<int64_t, cute::Int<1>, int64_t>;
  using StrideC = typename Gemm::StrideC;

  StrideA a_stride{lda, cute::Int<1>{}, 0};
  StrideB b_stride{ldb, cute::Int<1>{}, 0};
  StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}};

  typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b);

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
                                                       b_stride};

  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          std::forward<EpilogueArgs>(epilogue_params)...),
      c_ptr, c_stride, c_ptr, c_stride};

  cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
                                  epilogue_args);
}

}  // namespace vllm::c3x
@ -2,9 +2,6 @@
// clang-format will break include orders
// clang-format off
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>

#include "cutlass/cutlass.h"

@ -32,21 +29,6 @@ using namespace cute;
namespace vllm {

// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};

template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,

@ -101,60 +83,4 @@ struct cutlass_3x_gemm {
  struct GemmKernel : public KernelType {};
};

template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                         torch::Tensor const& b,
                         EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideA = Stride<int64_t, Int<1>, int64_t>;
  using StrideB = Stride<int64_t, Int<1>, int64_t>;
  using StrideC = typename Gemm::StrideC;

  StrideA a_stride{lda, Int<1>{}, 0};
  StrideB b_stride{ldb, Int<1>{}, 0};
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  using GemmKernel = typename Gemm::GemmKernel;
  typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
                                                       b_stride};

  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          std::forward<EpilogueArgs>(epilogue_params)...),
      c_ptr, c_stride, c_ptr, c_stride};

  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
                                      prob_shape, mainloop_args, epilogue_args};

  // Launch the CUTLASS GEMM kernel.
  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
  CUTLASS_CHECK(status);
}

}  // namespace vllm
@ -0,0 +1,24 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

namespace vllm {

void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     torch::Tensor const& a_scales,
                                     torch::Tensor const& b_scales,
                                     torch::Tensor const& azp_adj,
                                     std::optional<torch::Tensor> const& azp,
                                     std::optional<torch::Tensor> const& bias) {
  if (azp) {
    return cutlass_scaled_mm_sm90_int8_epilogue<
        c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,
                                         *azp, bias);
  } else {
    return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}

}  // namespace vllm
@ -0,0 +1,24 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

namespace vllm {

void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
                                          torch::Tensor const& a,
                                          torch::Tensor const& b,
                                          torch::Tensor const& a_scales,
                                          torch::Tensor const& b_scales) {
  if (out.dtype() == torch::kBFloat16) {
    cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
        out, a, b, a_scales, b_scales);

  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
        out, a, b, a_scales, b_scales);
  }
}

}  // namespace vllm
@ -0,0 +1,168 @@
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"

#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"

#include "cutlass_gemm_caller.cuh"

namespace vllm {

using namespace cute;

template <typename OutType, int GroupSizeM_, int GroupSizeN_, int GroupSizeK_,
          int TileSizeM_ = 128, class ClusterShape = Shape<_1, _2, _1>>
struct cutlass_3x_gemm_fp8_blockwise {
  using GroupSizeM = Int<GroupSizeM_>;
  using GroupSizeN = Int<GroupSizeN_>;
  using GroupSizeK = Int<GroupSizeK_>;
  using TileSizeM = Int<TileSizeM_>;

  static_assert(TileSizeM_ % GroupSizeM_ == 0,
                "TileSizeM must be a multiple of GroupSizeM");

  using ElementAB = cutlass::float_e4m3_t;

  using ElementA = ElementAB;
  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  using ElementB = ElementAB;
  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  using ElementD = OutType;
  using StrideD = Stride<int64_t, Int<1>, Int<0>>;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

  using ElementC = void;
  using StrideC = StrideD;
  static constexpr int AlignmentC = AlignmentD;

  using ElementAccumulator = float;
  using ElementBlockScale = float;
  using ElementCompute = float;
  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using TileShape = Shape<TileSizeM, GroupSizeN, GroupSizeK>;

  using KernelSchedule = cutlass::gemm::
      KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
          GroupSizeM_>;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

  using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
      cutlass::epilogue::fusion::Sm90AccFetch>;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
          ElementAccumulator, ElementCompute, ElementC, StrideC, AlignmentC,
          ElementD, StrideD, AlignmentD, EpilogueSchedule,
          StoreEpilogueCompute>::CollectiveOp;

  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB,
          LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>>;

  struct GemmKernel : public KernelType {};

  using StrideA = typename GemmKernel::StrideA;
  using StrideB = typename GemmKernel::StrideB;
};

template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
                                   torch::Tensor const& b,
                                   torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
  using GemmKernel = typename Gemm::GemmKernel;

  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  auto prob_shape = c3x::get_problem_shape(a, b);
  int32_t m = get<0>(prob_shape), n = get<1>(prob_shape),
          k = get<2>(prob_shape);

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideA = Stride<int64_t, Int<1>, int64_t>;
  using StrideB = Stride<int64_t, Int<1>, int64_t>;
  using StrideC = typename Gemm::StrideC;

  StrideA a_stride{lda, Int<1>{}, 0};
  StrideB b_stride{ldb, Int<1>{}, 0};
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
  auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());

  // Check is the t is contiguous and is 1D or 2D with one of the dimensions
  // being 1 (i.e. a row or column vector)
  auto is_contiguous_vector = [](const torch::Tensor& t) {
    auto t_sizes = t.sizes();
    return t.is_contiguous() &&
           (t.dim() == 1 ||
            (t.dim() == 2 &&
             *std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
  };

  // TODO(lucas): lets clean-up the kernel so that we pass in Strides so
  // we don't have to deal with enforcing implicit layouts
  TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value);
  TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value);
  TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales),
              "a_scales must be M major");
  TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value);
  TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value);
  TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales),
              "b_scales must be K major");
  typename GemmKernel::MainloopArguments mainloop_args{
      a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr};

  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      {}, c_ptr, c_stride, c_ptr, c_stride};

  c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
                                       epilogue_args);
}

template <typename OutType>
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
                                              torch::Tensor const& a,
                                              torch::Tensor const& b,
                                              torch::Tensor const& a_scales,
                                              torch::Tensor const& b_scales) {
  cutlass_gemm_caller_blockwise<
      cutlass_3x_gemm_fp8_blockwise<OutType, 1, 128, 128>>(out, a, b, a_scales,
                                                           b_scales);
}

}  // namespace vllm
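The checks in cutlass_gemm_caller_blockwise above pin down the layouts the dispatched kernel (instantiated as cutlass_3x_gemm_fp8_blockwise<OutType, 1, 128, 128>) expects: a_scales of shape (M, K/128) in M-major order and b_scales of shape (K/128, N/128) in K-major order. A hedged sketch of how a caller might produce scales in that layout (tensor and function names are illustrative):

import torch

def make_blockwise_scales(m: int, n: int, k: int, device: str = "cuda"):
    # 1x128 groups for activations: one scale per row and per 128-wide K block
    a_scales = torch.rand(m, k // 128, device=device, dtype=torch.float32)
    # 128x128 blocks for weights
    b_scales = torch.rand(k // 128, n // 128, device=device, dtype=torch.float32)
    # The checks accept scales whose *first* dimension is contiguous
    # (stride(0) == 1); transposing a contiguous transpose gives that layout
    # without changing any values.
    a_scales = a_scales.t().contiguous().t()  # M-major
    b_scales = b_scales.t().contiguous().t()  # K-major
    return a_scales, b_scales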
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp (new file)
@ -0,0 +1,33 @@
#pragma once

#include <torch/all.h>

namespace vllm {

void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                std::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 std::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     torch::Tensor const& a_scales,
                                     torch::Tensor const& b_scales,
                                     torch::Tensor const& azp_adj,
                                     std::optional<torch::Tensor> const& azp,
                                     std::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
                                          torch::Tensor const& a,
                                          torch::Tensor const& b,
                                          torch::Tensor const& a_scales,
                                          torch::Tensor const& b_scales);

}  // namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu (new file)
@ -0,0 +1,24 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

namespace vllm {

void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}

}  // namespace vllm
@ -1,6 +1,7 @@
#pragma once

#include "scaled_mm_c3x.cuh"
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"

/**
 * This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm

@ -9,6 +10,8 @@
namespace vllm {

using c3x::cutlass_gemm_caller;

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {

@ -93,4 +96,25 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
  }
}

template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_fp8_epilogue(torch::Tensor& out,
                                         torch::Tensor const& a,
                                         torch::Tensor const& b,
                                         EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  if (out.dtype() == torch::kBFloat16) {
    return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                          cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                          cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}

}  // namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu (new file)
@ -0,0 +1,24 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

namespace vllm {

void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}

}  // namespace vllm
@ -1,6 +1,7 @@
#pragma once

#include "scaled_mm_c3x.cuh"
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"

/**
 * This file defines Gemm kernel configurations for SM90 (int8) based on the

@ -9,6 +10,8 @@
namespace vllm {

using c3x::cutlass_gemm_caller;

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_default {

@ -137,4 +140,24 @@ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
  }
}

template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_int8_epilogue(torch::Tensor& out,
                                          torch::Tensor const& a,
                                          torch::Tensor const& b,
                                          EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  if (out.dtype() == torch::kBFloat16) {
    return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                           Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}

}  // namespace vllm
@ -1,52 +1,13 @@
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12000

#include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
#include "scaled_mm_c3x_sm90_int8_dispatch.cuh"

#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
using namespace vllm;

#include "c3x/scaled_mm_kernels.hpp"

#include "core/math.hpp"

/*
  This file defines quantized GEMM operations using the CUTLASS 3.x API, for
  NVIDIA GPUs with sm90a (Hopper) or later.
*/

template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(b.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                             Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else {
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::bfloat16_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  }
}

void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,

@ -54,14 +15,50 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
    TORCH_CHECK(bias->dtype() == c.dtype(),
                "currently bias dtype must match output dtype ", c.dtype());
    return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
        c, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogue>(
        c, a, b, a_scales, b_scales);
  }
  using GroupShape = std::array<int64_t, 2>;

  int M = a.size(0), N = b.size(1), K = a.size(1);

  GroupShape a_scale_group_shape = [&, &s = a_scales]() -> GroupShape {
    if (s.numel() == 1) return {M, K};  // tensor-wise
    if (s.dim() == 2)
      return {ceil_div(a.size(0), s.size(0)), ceil_div(a.size(1), s.size(1))};
    TORCH_CHECK(false, "Unsupported scale shape for scale_a");
  }();

  GroupShape b_scale_group_shape = [&, &s = b_scales]() -> GroupShape {
    if (s.numel() == 1) return {K, N};  // tensor-wise
    if (s.dim() == 2)
      return {ceil_div(b.size(0), s.size(0)), ceil_div(b.size(1), s.size(1))};
    TORCH_CHECK(false, "Unsupported scale shape for scale_b");
  }();

  if ((a_scale_group_shape == GroupShape{M, K} ||
       a_scale_group_shape == GroupShape{1, K}) &&
      (b_scale_group_shape == GroupShape{K, N} ||
       b_scale_group_shape == GroupShape{K, 1})) {
    // "standard per-tensor/per-token/per-channel" scaling
    TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
    if (a.dtype() == torch::kFloat8_e4m3fn) {
      vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
    } else {
      TORCH_CHECK(a.dtype() == torch::kInt8);
      vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
    }
  } else if (a_scale_group_shape == GroupShape{1, 128} &&
             b_scale_group_shape == GroupShape{128, 128}) {
    // 1x128 per-token group scales for activations
    // 128x128 blockwise scales for weights
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn &&
                    b.dtype() == torch::kFloat8_e4m3fn,
                "Currently only FP8 is supported for A group shape 1x128 and "
                "B group shape 128x128");
    TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");

    vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
  } else {
    TORCH_CHECK(false, "Unsupported scale group shapes for CUTLASS 3.x GEMM");
  }
}

@ -75,13 +72,6 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (azp) {
    return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzpToken>(
        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
  } else {
    return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
  vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
                                        azp, bias);
}

#endif
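The rewritten cutlass_scaled_mm_sm90 above infers a per-element group shape from each scale tensor and routes to the per-token/per-channel kernels or to the new blockwise kernel accordingly. A rough Python rendering of that inference, for illustration only (the authoritative logic is the C++ above):

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def infer_group_shape(tensor_shape, scale_shape):
    # how many rows/columns of the quantized tensor each scale element covers
    return tuple(ceil_div(t, s) for t, s in zip(tensor_shape, scale_shape))

# (M, K) = (256, 512) activations with one scale per row per 128-wide K block
assert infer_group_shape((256, 512), (256, 4)) == (1, 128)
# (K, N) = (512, 1024) weights with 128x128 blockwise scales
assert infer_group_shape((512, 1024), (4, 8)) == (128, 128)
# a single scalar scale covers the whole tensor (tensor-wise)
assert infer_group_shape((256, 512), (1, 1)) == (256, 512)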
@ -89,15 +89,12 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
@ -272,6 +272,10 @@ struct MacheteCollectiveMma {
  using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;

  using PipelineParams = typename MainloopPipeline::Params;

  // One threads per CTA are producers (1 for operand tile)
  static constexpr int NumProducerThreadEvents = 1;

  using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}),
                                             shape<1>(SmemLayoutAtomScale{})));
@ -10,6 +10,7 @@ import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import cdiv

from .utils import baseline_scaled_mm, to_fp8, to_int8

@ -39,6 +40,11 @@ CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

# -1 means full extent in that dimension
TENSORWISE_GROUP_SHAPE = (-1, -1)
PER_TOKEN_GROUP_SHAPE = (1, -1)
PER_OUT_CH_GROUP_SHAPE = (-1, 1)

capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]

@ -47,11 +53,22 @@ def rand_int8(shape: tuple, device: str = "cuda"):
    return to_int8(torch.rand(shape, device=device) * 255 - 128)


def group_scale_helper(shape, group_shape):
    return [shape[i] if s < 0 else s for i, s in enumerate(group_shape)]


def scale_shape(shape, group_shape):
    assert len(shape) == len(group_shape)
    group_shape = group_scale_helper(shape, group_shape)
    return tuple(
        cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
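For example, the helpers just defined map the sentinel group shapes to concrete scale-tensor shapes as follows (values illustrative, not part of the test file):

# for an (m, k) = (512, 4096) activation tensor
assert scale_shape((512, 4096), PER_TOKEN_GROUP_SHAPE) == (512, 1)   # (1, -1)
assert scale_shape((512, 4096), TENSORWISE_GROUP_SHAPE) == (1, 1)    # (-1, -1)
assert scale_shape((512, 4096), (1, 128)) == (512, 32)               # blockwise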
def cutlass_fp8_gemm_helper(m: int,
|
def cutlass_fp8_gemm_helper(m: int,
|
||||||
n: int,
|
n: int,
|
||||||
k: int,
|
k: int,
|
||||||
per_token_act_quant: bool,
|
a_scale_group_shape: tuple,
|
||||||
per_out_channel_weight_quant: bool,
|
b_scale_group_shape: tuple,
|
||||||
use_bias: bool,
|
use_bias: bool,
|
||||||
out_dtype: Type[torch.dtype] = torch.bfloat16,
|
out_dtype: Type[torch.dtype] = torch.bfloat16,
|
||||||
device: str = "cuda"):
|
device: str = "cuda"):
|
||||||
@ -60,13 +77,17 @@ def cutlass_fp8_gemm_helper(m: int,
    a = to_fp8(torch.randn((m, k), device=device))
    b = to_fp8(torch.randn((n, k), device=device).t())

    m_a_scales = m if per_token_act_quant else 1
    n_b_scales = n if per_out_channel_weight_quant else 1

    scale_a = (torch.randn((m_a_scales, 1), device=device,
                           dtype=torch.float32))
    scale_b = (torch.randn((1, n_b_scales), device=device,
                           dtype=torch.float32))
    a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
    b_scales_shape = scale_shape(b.shape, b_scale_group_shape)

    scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
    scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))

    # make scales M-major for blockwise quant, doesn't affect 1D scales
    scale_a = scale_a.t().contiguous().t()
    # make scales K-major for blockwise quant, doesn't affect 1D scales
    scale_b = scale_b.t().contiguous().t()

    if use_bias:
        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
    else:
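A short illustrative note on the `.t().contiguous().t()` round trip used above (assumed sizes, not from the commit): it keeps the shape and values but flips the storage order, which is what "M-major" / "K-major" refers to.

import torch

scale_a = torch.randn(4, 3)                # e.g. 4 row groups x 3 K groups
print(scale_a.stride())                    # (3, 1): row-major, K varies fastest

scale_a = scale_a.t().contiguous().t()     # same shape and values, column-major storage
print(scale_a.shape, scale_a.stride())     # torch.Size([4, 3]) (1, 4): M varies fastest

# A 1D-style scale such as shape (m, 1) or (1, n) has only one element per row
# or column, so the round trip leaves its memory layout effectively unchanged,
# which is why the comment says it "doesn't affect 1D scales".
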
@ -84,8 +105,8 @@ def cutlass_fp8_gemm_helper(m: int,
def cutlass_int8_gemm_helper(m: int,
                             n: int,
                             k: int,
                             per_token_act_quant: bool,
                             a_scale_group_shape: tuple,
                             per_out_channel_weight_quant: bool,
                             b_scale_group_shape: tuple,
                             use_bias: bool,
                             out_dtype: Type[torch.dtype] = torch.bfloat16,
                             device: str = "cuda"):
@ -94,13 +115,11 @@ def cutlass_int8_gemm_helper(m: int,
    a = to_int8(torch.randn((m, k), device=device) * 5)
    b = to_int8(torch.randn((n, k), device=device).t() * 5)

    m_a_scales = m if per_token_act_quant else 1
    n_b_scales = n if per_out_channel_weight_quant else 1

    scale_a = (torch.randn((m_a_scales, 1), device=device,
                           dtype=torch.float32))
    scale_b = (torch.randn((1, n_b_scales), device=device,
                           dtype=torch.float32))
    a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
    b_scales_shape = scale_shape(b.shape, b_scale_group_shape)

    scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
    scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))

    if use_bias:
        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
@ -117,82 +136,135 @@ def cutlass_int8_gemm_helper(m: int,


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("a_scale_group_shape",
                         [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
                         [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(not current_platform.has_device_capability(89),
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
                          per_out_ch: bool, use_bias: bool):
def test_cutlass_fp8_gemm(m: int, n: int, k: int, a_scale_group_shape,
                          b_scale_group_shape, use_bias: bool):
    cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
    cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
                            use_bias)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
                         [((1, 128), (128, 128))])
@pytest.mark.parametrize("use_bias", [False])
@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="FP8 blockwise is not supported on this GPU type.")
def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
                                          a_scale_group_shape,
                                          b_scale_group_shape, use_bias: bool):
    if k % b_scale_group_shape[0] != 0 or n % b_scale_group_shape[1] != 0:
        return
    if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
        return
    cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
                            use_bias)

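For intuition, a small sanity check of the scale shapes this blockwise parametrization implies (illustrative only, assuming the `scale_shape` helper above is in scope; the sizes are example values, not necessarily members of MNK_FACTORS):

m, n, k = 512, 512, 512  # example sizes that satisfy both divisibility guards
assert scale_shape((m, k), (1, 128)) == (512, 4)    # scale_a: per row, per 128-wide K slice
assert scale_shape((k, n), (128, 128)) == (4, 4)    # scale_b: per 128x128 block of the (K, N) weight
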

@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("a_scale_group_shape",
                         [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
                         [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
                           per_out_ch: bool, use_bias: bool):
def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
                           b_scale_group_shape, use_bias: bool):
    cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
    cutlass_int8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
                             use_bias)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("a_scale_group_shape",
                         [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
                         [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
                                        b_scale_group_shape,
                                        out_dtype: Type[torch.dtype],
                                        use_bias: bool):
    cutlass_int8_gemm_helper(512,
                             512,
                             512,
                             per_act_token,
                             per_out_ch,
                             a_scale_group_shape,
                             b_scale_group_shape,
                             use_bias,
                             out_dtype=out_dtype)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("a_scale_group_shape",
                         [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
                         [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(not current_platform.has_device_capability(89),
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
                                       b_scale_group_shape,
                                       out_dtype: Type[torch.dtype],
                                       use_bias: bool):
    cutlass_fp8_gemm_helper(512,
                            512,
                            512,
                            per_act_token,
                            per_out_ch,
                            a_scale_group_shape,
                            b_scale_group_shape,
                            use_bias,
                            out_dtype=out_dtype)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
                         [((1, 128), (128, 128))])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [False])
@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="FP8 blockwise is not supported on this GPU type.")
def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
                                                b_scale_group_shape,
                                                out_dtype: Type[torch.dtype],
                                                use_bias: bool):
    cutlass_fp8_gemm_helper(512,
                            512,
                            512,
                            a_scale_group_shape,
                            b_scale_group_shape,
                            use_bias,
                            out_dtype=out_dtype)


@pytest.mark.parametrize("a_scale_group_shape",
                         [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
                         [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(not current_platform.has_device_capability(89),
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
def test_cutlass_fp8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
                                  use_bias: bool, device: str):
    cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
                            torch.bfloat16, device)
    cutlass_fp8_gemm_helper(512, 512, 512, a_scale_group_shape,
                            b_scale_group_shape, use_bias, torch.bfloat16,
                            device)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("a_scale_group_shape",
                         [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
                         [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
                                   use_bias: bool, device: str):
    cutlass_int8_gemm_helper(512,
                             512,
                             512,
                             per_act_token,
                             per_out_ch,
                             a_scale_group_shape,
                             b_scale_group_shape,
                             use_bias,
                             out_dtype=torch.bfloat16,
                             device=device)
@ -203,28 +275,32 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
# of a large power of two. In any case, the kernel will have a naive fallback
# when N and K are not divisible by 16. But M is the number of tokens and the
# kernel must handle any M thrown at it.
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("a_scale_group_shape",
                         [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
                         [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(not current_platform.has_device_capability(89),
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
def test_cutlass_fp8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
                                  use_bias: bool):
    for nk in range(32, 128, 32):
        for m in range(1, 128):
            cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
                                    use_bias)
            cutlass_fp8_gemm_helper(m, nk, nk, a_scale_group_shape,
                                    b_scale_group_shape, use_bias)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("a_scale_group_shape",
                         [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
                         [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
                                   use_bias: bool):
    for nk in range(32, 128, 32):
        for m in range(1, 128):
            cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
                                     use_bias)
            cutlass_int8_gemm_helper(m, nk, nk, a_scale_group_shape,
                                     b_scale_group_shape, use_bias)


@pytest.mark.parametrize("m", [32, 64, 128])
@ -1119,8 +1119,36 @@ def baseline_scaled_mm(a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: Type[torch.dtype],
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    output = (scale_a * (scale_b * (torch.mm(
        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)

    # We treat N-dimensional group scaling as extended numpy-style broadcasting.
    # Where numpy simply stretches dimensions with an extent of 1 to match the
    # target shape by repeating the data along that dimension (broadcasting),
    # we extend these semantics: if the extent of a dimension in the source
    # shape is not 1 and does not match the target shape, we repeat each
    # element along that dimension target_shape[dim] // src_shape[dim] times.
    # For example, if we have:
    #    a = [[1, 2], and target_shape = (2, 4)
    #         [3, 4]]
    # then we would expand a to:
    #    a = [[1, 1, 2, 2],
    #         [3, 3, 4, 4]]
    # NOTE: this function does not explicitly broadcast dimensions with an
    # extent of 1, since this can be done implicitly by pytorch.
    def group_broadcast(t, shape):
        for i, s in enumerate(shape):
            if t.shape[i] != s and t.shape[i] != 1:
                assert s % t.shape[i] == 0
                t = t.unsqueeze(i + 1)\
                    .expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\
                    .flatten(i, i + 1)
        return t

    scale_a = group_broadcast(scale_a, a.shape)
    scale_b = group_broadcast(scale_b, b.shape)

    output = torch.mm((scale_a * a.to(dtype=torch.float32)),
                      (scale_b * b.to(dtype=torch.float32))).to(out_dtype)

    if bias is not None:
        output = output + bias

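A standalone check of the broadcast rule described in the comment above (the nested helper is re-stated so the snippet runs on its own; purely illustrative):

import torch

def group_broadcast(t, shape):
    for i, s in enumerate(shape):
        if t.shape[i] != s and t.shape[i] != 1:
            assert s % t.shape[i] == 0
            t = t.unsqueeze(i + 1)\
                .expand(*t.shape[:i + 1], s // t.shape[i], *t.shape[i + 1:])\
                .flatten(i, i + 1)
    return t

a = torch.tensor([[1., 2.], [3., 4.]])
print(group_broadcast(a, (2, 4)))
# tensor([[1., 1., 2., 2.],
#         [3., 3., 4., 4.]])

# Dimensions with extent 1 are left alone and broadcast implicitly by pytorch:
scale = torch.randn(4, 1)
print(group_broadcast(scale, (512, 512)).shape)  # torch.Size([512, 1])
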
@ -441,6 +441,28 @@ def cutlass_scaled_mm(a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    `cutlass_scaled_mm` implements a fused version of
        `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
    where scale_a * a and scale_b * b are implemented using numpy-style
    broadcasting.

    In order to support blockwise scaling like that found in DeepSeek V3 we
    also support extended "group" broadcast rules. We extend the numpy-style
    broadcasting rules with the following rule:
        "if the extent of a dimension in the source shape is between 1 and the
        corresponding extent in the target shape, we repeat each element along
        that dimension target_shape[dim] // src_shape[dim] times consecutively"
    For example, if we have:
        a = [[1, 2], and target_shape = (2, 4)
             [3, 4]]
    then we would expand a to:
        a = [[1, 1, 2, 2],
             [3, 3, 4, 4]]
    Currently we only support the case:
        scale_a.shape * [1, 128] == a.shape
        scale_b.shape * [128, 128] == b.shape
    """
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or bias.shape[0] == b.shape[
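A minimal end-to-end sketch of the blockwise path this docstring describes (illustrative only; it assumes a Hopper-class GPU with fp8 support and uses arbitrary sizes that are multiples of 128 where required):

import torch
from vllm import _custom_ops as ops

M, N, K = 256, 512, 384
a = torch.randn(M, K, device="cuda").clamp(-1, 1).to(torch.float8_e4m3fn)
# b must be (K, N); creating it as (N, K) and transposing gives the
# column-major layout the tests above use.
b = torch.randn(N, K, device="cuda").clamp(-1, 1).to(torch.float8_e4m3fn).t()

# Blockwise scales: one scale per 1x128 group of a, per 128x128 group of b.
scale_a = torch.rand(M, K // 128, device="cuda", dtype=torch.float32)
scale_b = torch.rand(K // 128, N // 128, device="cuda", dtype=torch.float32)
# Match the test helper's layout: scale_a M-major, scale_b K-major.
scale_a = scale_a.t().contiguous().t()
scale_b = scale_b.t().contiguous().t()

out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([256, 512])
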