[Kernel] Initial Machete W4A8 support + Refactors (#9855)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
parent c2170a5b39
commit 96d999fbe8
@ -2,8 +2,10 @@ import argparse
import copy
import itertools
import math
import os
import pickle as pkl
import time
from dataclasses import dataclass
from itertools import product
from typing import Callable, Iterable, List, Optional, Tuple

@ -15,11 +17,12 @@ from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales)
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales,
marlin_zero_points)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, pack_rows, quantize_weights)
pack_rows, quantize_weights)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import FlexibleArgumentParser

@ -27,149 +30,349 @@ DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
|
||||
DEFAULT_TP_SIZES = [1]
|
||||
|
||||
NVTX_PROFILE = os.environ.get("NVTX_PROFILE", False)
|
||||
|
||||
def machete_pack_weights(w_q: torch.tensor, wtype: ScalarType) -> torch.tensor:
|
||||
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
|
||||
w_q = w_q.t().contiguous().t() # make col major
|
||||
return ops.machete_prepack_B(w_q, wtype)
|
||||
if NVTX_PROFILE:
|
||||
import nvtx
|
||||
|
||||
|
||||
def make_bench_tensors(
|
||||
atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int,
|
||||
k: int
|
||||
) -> Tuple[torch.tensor, List[Tuple[torch.tensor, torch.tensor, torch.tensor,
|
||||
torch.tensor]]]:
|
||||
def terse_type_name(dt):
|
||||
return {
|
||||
torch.bfloat16: "bf16",
|
||||
torch.float16: "fp16",
|
||||
torch.int8: "int8",
|
||||
torch.float8_e4m3fn: "fp8",
|
||||
torch.bfloat16: "bf16",
|
||||
torch.float: "float",
|
||||
torch.int: "int",
|
||||
}[dt]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkTensors:
|
||||
w_ref: torch.Tensor
|
||||
a: torch.Tensor
|
||||
|
||||
w_q: torch.Tensor
|
||||
group_size: Optional[int]
|
||||
wtype: ScalarType
|
||||
w_g_s: torch.Tensor
|
||||
w_g_zp: Optional[torch.Tensor]
|
||||
w_ch_s: Optional[torch.Tensor]
|
||||
w_tok_s: Optional[torch.Tensor]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TypeConfig:
|
||||
act_type: torch.dtype
|
||||
weight_type: ScalarType
|
||||
output_type: Optional[torch.dtype]
|
||||
group_scale_type: Optional[torch.dtype]
|
||||
group_zero_type: Optional[torch.dtype]
|
||||
channel_scale_type: Optional[torch.dtype]
|
||||
token_scale_type: Optional[torch.dtype]
|
||||
|
||||
|
||||
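# A minimal example configuration (values assumed for illustration, mirroring the
# W4A8 path this benchmark targets): int8 activations, 4-bit uint4b8 weights with
# fp16 group scales, per-channel and per-token float scales, fp16 output.
example_w4a8_types = TypeConfig(
    act_type=torch.int8,
    weight_type=scalar_types.uint4b8,
    output_type=torch.float16,
    group_scale_type=torch.float16,
    group_zero_type=None,
    channel_scale_type=torch.float,
    token_scale_type=torch.float,
)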
def rand_data(shape, dtype=torch.float16, scale=1):
|
||||
if dtype.is_floating_point:
|
||||
return (scale * torch.rand(shape, device="cuda") - 0.3).to(dtype)
|
||||
else:
|
||||
return torch.randint(-15, 15, shape, dtype=dtype, device="cuda")
|
||||
|
||||
|
||||
def quantize_and_pack(atype: torch.dtype,
|
||||
w: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
stype: Optional[torch.dtype],
|
||||
group_size: Optional[int],
|
||||
zero_points: bool = False):
|
||||
assert wtype.is_integer(), "TODO: support floating point weights"
|
||||
|
||||
w_ref, w_q, w_s, w_zp = quantize_weights(
|
||||
w,
|
||||
wtype,
|
||||
group_size=group_size,
|
||||
zero_points=zero_points,
|
||||
# to match how the kernel applies zps
|
||||
ref_zero_points_after_scales=True)
|
||||
|
||||
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
|
||||
return w_ref, w_q, w_s, w_zp
|
||||
|
||||
|
||||
def create_bench_tensors(shape: Tuple[int, int, int], types: TypeConfig,
|
||||
group_size: Optional[int]) -> List[BenchmarkTensors]:
|
||||
m, n, k = shape
|
||||
|
||||
# we want to make sure that weights don't fit into L2 cache between runs so
|
||||
# we construct enough weights to exceed L2 cache, which is 50mb on a H100
|
||||
# so we target total weight size > 2*50mb
|
||||
num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits))
|
||||
num_weights = math.ceil(2 * 50 * 1024**2 * 8 /
|
||||
(k * n * types.weight_type.size_bits))
|
||||
|
||||
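# Worked example of the sizing rule above (shape assumed for illustration, not a
# benchmark default): for k = 8192, n = 8192 and 4-bit weights, one weight matrix is
# 8192 * 8192 * 4 / 8 = 32 MiB, so exceeding 2 * 50 MiB of L2 needs ceil(100/32) = 4
# distinct weight copies.
import math  # already imported at the top of this file
k_ex, n_ex, bits_ex = 8192, 8192, 4
num_weights_ex = math.ceil(2 * 50 * 1024**2 * 8 / (k_ex * n_ex * bits_ex))
assert num_weights_ex == 4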
a = torch.randn((m, k), device="cuda", dtype=atype) * 5
|
||||
weights = [
|
||||
torch.randn((k, n), device="cuda", dtype=atype)
|
||||
for _ in range(num_weights)
|
||||
]
|
||||
quanitized_weights = [
|
||||
quantize_weights(w, wtype, group_size) for w in weights
|
||||
]
|
||||
a = rand_data((m, k), types.act_type, scale=5)
|
||||
|
||||
return a, quanitized_weights
|
||||
benchmark_tensors: List[BenchmarkTensors] = []
|
||||
for _ in range(num_weights):
|
||||
w = rand_data((k, n), types.act_type, scale=5)
|
||||
|
||||
if types.group_scale_type is not None:
|
||||
w = w.to(types.group_scale_type)
|
||||
if w.dtype.itemsize == 1:
|
||||
w = w.to(torch.float16)
|
||||
|
||||
w_ref, w_q_packed, w_s, w_zp = quantize_and_pack(
|
||||
a.dtype, w, types.weight_type, types.group_scale_type, group_size,
|
||||
types.group_zero_type is not None)
|
||||
|
||||
if not a.dtype.is_floating_point:
|
||||
aiinfo = torch.iinfo(a.dtype)
|
||||
w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max)
|
||||
|
||||
w_ref = w_ref.to(torch.float32)
|
||||
|
||||
w_ch_s = None if types.channel_scale_type is None else\
|
||||
rand_data((n,), types.channel_scale_type)
|
||||
w_tok_s = None if types.token_scale_type is None else\
|
||||
rand_data((m,), types.token_scale_type)
|
||||
|
||||
benchmark_tensors.append(
|
||||
BenchmarkTensors(w_ref=w_ref,
|
||||
a=a,
|
||||
w_q=w_q_packed,
|
||||
wtype=types.weight_type,
|
||||
w_g_s=w_s,
|
||||
w_g_zp=w_zp,
|
||||
group_size=group_size,
|
||||
w_ch_s=w_ch_s,
|
||||
w_tok_s=w_tok_s))
|
||||
|
||||
return benchmark_tensors
|
||||
|
||||
|
||||
def torch_matmul_f16_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
||||
a = bt.a
|
||||
w = bt.w_ref.to(bt.a.dtype) # use float reference tensor
|
||||
if a.dtype not in [torch.float16, torch.bfloat16]:
|
||||
a = a.to(torch.float16)
|
||||
w = w.to(torch.float16)
|
||||
return lambda: torch.matmul(a, w)
|
||||
|
||||
|
||||
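# The torch.matmul callable above is the unquantized fp16/bf16 baseline; the helpers
# below build comparable callables for cutlass_scaled_mm, Marlin, and Machete so all
# backends are timed over the same BenchmarkTensors.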
def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
||||
if bt.w_ch_s is not None and bt.w_tok_s is not None:
|
||||
scale_a = bt.w_tok_s.to(torch.float32)
|
||||
scale_b = bt.w_ch_s.to(torch.float32)
|
||||
else:
|
||||
scale_a = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
|
||||
scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
|
||||
w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t()
|
||||
return lambda: ops.cutlass_scaled_mm(
|
||||
bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16)
|
||||
|
||||
|
||||
def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
||||
device = bt.a.device
|
||||
|
||||
workspace = MarlinWorkspace(bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
|
||||
if bt.w_g_zp is None:
|
||||
w_zp = torch.empty(0, dtype=torch.int, device=device)
|
||||
else:
|
||||
w_zp = marlin_zero_points(bt.w_g_zp, bt.w_ref.shape[0],
|
||||
bt.w_ref.shape[1], bt.wtype.size_bits)
|
||||
|
||||
if bt.group_size is None:
|
||||
w_s = torch.tensor([], device="cuda", dtype=torch.half)
|
||||
else:
|
||||
w_s = marlin_permute_scales(bt.w_g_s, bt.w_ref.shape[0],
|
||||
bt.w_ref.shape[1], bt.group_size)
|
||||
|
||||
sort_indices = torch.empty(0, dtype=torch.int, device=device)
|
||||
g_idx = torch.empty(0, dtype=torch.int, device=device)
|
||||
w_q = ops.gptq_marlin_repack(bt.w_q, sort_indices, bt.w_ref.shape[0],
|
||||
bt.w_ref.shape[1], bt.wtype.size_bits)
|
||||
|
||||
if bt.a.dtype.is_floating_point:
|
||||
assert bt.w_ch_s is None
|
||||
assert bt.w_tok_s is None
|
||||
assert bt.group_size is not None
|
||||
|
||||
fn = lambda: ops.gptq_marlin_gemm(a=bt.a,
|
||||
b_q_weight=w_q,
|
||||
b_scales=w_s,
|
||||
b_zeros=w_zp,
|
||||
g_idx=g_idx,
|
||||
perm=sort_indices,
|
||||
workspace=workspace.scratch,
|
||||
b_q_type=bt.wtype,
|
||||
size_m=bt.a.shape[0],
|
||||
size_n=bt.w_ref.shape[1],
|
||||
size_k=bt.w_ref.shape[0],
|
||||
is_k_full=True)
|
||||
else:
|
||||
assert bt.a.dtype == torch.int8
|
||||
assert bt.wtype == scalar_types.uint4b8
|
||||
|
||||
if bt.w_ch_s is not None:
|
||||
s_ch = bt.w_ch_s.to(torch.float32)
|
||||
else:
|
||||
s_ch = torch.ones(bt.w_ref.shape[1],
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
|
||||
if bt.w_tok_s is not None:
|
||||
s_tok = bt.w_tok_s.to(torch.float32)
|
||||
else:
|
||||
s_tok = torch.ones(bt.a.shape[0],
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
|
||||
fn = lambda: ops.marlin_qqq_gemm(a=bt.a,
|
||||
b_q_weight=w_q,
|
||||
s_group=w_s,
|
||||
s_tok=s_tok,
|
||||
s_ch=s_ch,
|
||||
workspace=workspace.scratch,
|
||||
size_m=bt.a.shape[0],
|
||||
size_n=bt.w_ref.shape[1],
|
||||
size_k=bt.w_ref.shape[0])
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def machete_create_bench_fn(bt: BenchmarkTensors,
|
||||
out_type=torch.dtype,
|
||||
schedule=None) -> Callable:
|
||||
w_q = bt.w_q.t().contiguous().t() # make col major
|
||||
w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype,
|
||||
None if bt.w_g_s is None else bt.w_g_s.dtype)
|
||||
|
||||
w_g_zp = bt.w_g_zp
|
||||
if w_g_zp is not None:
|
||||
w_g_zp = -1 * bt.w_g_s * (w_g_zp.to(bt.w_g_s.dtype))
|
||||
|
||||
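# Why the sign flip above (sketch of the assumed convention): quantize_and_pack was
# called with ref_zero_points_after_scales=True, i.e. dequantization is modeled as
#   w = w_q * s - s * zp  ==  (w_q - zp) * s
# so passing -s * zp as b_group_zeros lets the kernel apply the zero-point term after
# scaling while still matching the float reference from quantize_weights.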
return lambda: ops.machete_mm(
|
||||
a=bt.a,
|
||||
b_q=bt.w_q,
|
||||
b_type=bt.wtype,
|
||||
b_group_scales=bt.w_g_s,
|
||||
b_group_zeros=w_g_zp,
|
||||
b_group_size=bt.group_size,
|
||||
b_channel_scales=bt.w_ch_s,
|
||||
a_token_scales=bt.w_tok_s,
|
||||
out_type=out_type,
|
||||
schedule=schedule,
|
||||
)
|
||||
|
||||
|
||||
# impl
|
||||
|
||||
|
||||
# bench
|
||||
def bench_fn(label: str, sub_label: str, description: str,
|
||||
fn: Callable) -> TMeasurement:
|
||||
|
||||
min_run_time = 1
|
||||
return TBenchmark.Timer(
|
||||
stmt="fn()",
|
||||
|
||||
def bench_fns(label: str, sub_label: str, description: str,
|
||||
fns: List[Callable]):
|
||||
|
||||
min_run_time = 1 if not NVTX_PROFILE else 0.1
|
||||
res = TBenchmark.Timer(
|
||||
stmt="""
|
||||
for fn in fns:
|
||||
fn()
|
||||
""",
|
||||
globals={
|
||||
"fn": fn
|
||||
"fns": fns
|
||||
},
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description=description,
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
|
||||
if NVTX_PROFILE:
|
||||
with nvtx.annotate("mm-bench"), nvtx.annotate(
|
||||
f"{label}|{sub_label}|{description}"):
|
||||
fns[0]()
|
||||
|
||||
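# Profiling note (assumed workflow, not exercised by default): when the NVTX_PROFILE
# environment variable is set, each group of benchmarked fns also runs once inside an
# nvtx.annotate range, so individual kernels are easy to locate when the script is
# launched under an external profiler, e.g. (invocation assumed)
#   NVTX_PROFILE=1 nsys profile python benchmark_machete.py --act-type int8 ...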
def loop_over_weights(
|
||||
a: torch.tensor, weights: List[Tuple[torch.tensor, torch.tensor,
|
||||
torch.tensor, torch.tensor]],
|
||||
fn: Callable[[torch.tensor, torch.tensor, torch.tensor, torch.tensor],
|
||||
None]):
|
||||
for w_ref, w_q, w_s, _ in weights:
|
||||
fn(a, w_ref, w_q, w_s)
|
||||
return res
|
||||
|
||||
|
||||
_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
|
||||
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None
|
||||
|
||||
|
||||
def bench(atype: torch.dtype,
|
||||
wtype: ScalarType,
|
||||
def bench(types: TypeConfig,
|
||||
group_size: int,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
benchmark_marlinv1: bool = True,
|
||||
sweep_schedules: bool = True) -> Iterable[TMeasurement]:
|
||||
global _SWEEP_SCHEDULES_RESULTS
|
||||
sweep_schedules: bool = True) -> List[TMeasurement]:
|
||||
benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
|
||||
sub_label += f", L={len(benchmark_tensors)}"
|
||||
|
||||
a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
|
||||
sub_label += f", L={len(weights)}"
|
||||
|
||||
weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp)
|
||||
for w_ref, w_q, w_s, w_zp in weights]
|
||||
name_type_string = f"W{types.weight_type}"+\
|
||||
f"-A{terse_type_name(types.act_type)}"
|
||||
if types.group_scale_type is not None:
|
||||
name_type_string += f"-GS{terse_type_name(types.group_scale_type)}"
|
||||
if types.group_zero_type is not None:
|
||||
name_type_string += f"-GZ{terse_type_name(types.group_zero_type)}"
|
||||
if group_size is not None:
|
||||
name_type_string += f"-G{group_size}"
|
||||
if types.channel_scale_type is not None:
|
||||
name_type_string += f"-CS{terse_type_name(types.channel_scale_type)}"
|
||||
if types.token_scale_type is not None:
|
||||
name_type_string += f"-TS{terse_type_name(types.token_scale_type)}"
|
||||
|
||||
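# For the W4A8 path this renders to something like
# "Wuint4b8-Aint8-GSfp16-G128-CSfloat-TSfloat" (illustrative; the exact weight-type
# spelling depends on how ScalarType stringifies).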
timers = []
|
||||
# pytorch impl
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label, sub_label, "torch.matmul", lambda: loop_over_weights(
|
||||
a,
|
||||
weights,
|
||||
lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref),
|
||||
)))
|
||||
bench_fns(
|
||||
label, sub_label, "torch.matmul (fp16)",
|
||||
[torch_matmul_f16_create_bench_fn(bt)
|
||||
for bt in benchmark_tensors]))
|
||||
|
||||
if benchmark_marlinv1:
|
||||
w_ref = weights[0][0]
|
||||
|
||||
w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device)
|
||||
sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device)
|
||||
g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device)
|
||||
|
||||
def marlinv1_pack_weights(w_q: torch.tensor) -> torch.tensor:
|
||||
w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape)
|
||||
return ops.gptq_marlin_repack(w_q_gptq, sort_indices, *w_ref.shape,
|
||||
wtype.size_bits)
|
||||
|
||||
def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
|
||||
return marlin_permute_scales(w_s, *w_ref.shape, group_size)
|
||||
|
||||
weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q),
|
||||
marlinv1_permute_scales(w_s), w_zp)
|
||||
for w_ref, w_q, w_s, w_zp in weights]
|
||||
|
||||
workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
|
||||
# marlinv1
|
||||
if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn:
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label, sub_label, "marlin_orig", lambda: loop_over_weights(
|
||||
a, weights_marlinv1, lambda a, w_ref, w_q, w_s: ops.
|
||||
gptq_marlin_gemm(a,
|
||||
w_q,
|
||||
w_s,
|
||||
w_zp_empty,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
workspace.scratch,
|
||||
wtype,
|
||||
size_m=a.shape[0],
|
||||
size_n=w_ref.shape[1],
|
||||
size_k=w_ref.shape[0],
|
||||
is_k_full=True))))
|
||||
bench_fns(
|
||||
label, sub_label,
|
||||
f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [
|
||||
cutlass_scaled_mm_create_bench_fn(bt)
|
||||
for bt in benchmark_tensors
|
||||
]))
|
||||
|
||||
if types.act_type != torch.float8_e4m3fn:
|
||||
timers.append(
|
||||
bench_fns(label, sub_label, f"marlin ({name_type_string})",
|
||||
[marlin_create_bench_fn(bt)
|
||||
for bt in benchmark_tensors]))
|
||||
|
||||
# machete
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label, sub_label, "machete_heuristic", lambda: loop_over_weights(
|
||||
a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm(
|
||||
a, w_q, wtype, b_scales=w_s, b_group_size=group_size))))
|
||||
bench_fns(label, sub_label, f"machete ({name_type_string})", [
|
||||
machete_create_bench_fn(bt, out_type=types.output_type)
|
||||
for bt in benchmark_tensors
|
||||
]))
|
||||
|
||||
if sweep_schedules:
|
||||
global _SWEEP_SCHEDULES_RESULTS
|
||||
|
||||
print("Finding best schedule for machete")
|
||||
best = None
|
||||
best_schedule = None
|
||||
schedules = ops.machete_supported_schedules(wtype)
|
||||
schedules = ops.machete_supported_schedules(
|
||||
a_type=types.act_type,
|
||||
b_type=types.weight_type,
|
||||
group_scales_type=types.group_scale_type,
|
||||
group_zeros_type=types.group_zero_type,
|
||||
token_scales_type=types.token_scale_type,
|
||||
channel_scales_type=types.channel_scale_type,
|
||||
out_type=types.output_type)
|
||||
|
||||
if schedules is None or len(schedules) == 0:
|
||||
raise ValueError("No schedules found to sweep")
|
||||
|
||||
for schedule in reversed(schedules):
|
||||
schedule_M = int(schedule.split("_")[0].split("x")[1])
|
||||
|
||||
@ -177,16 +380,11 @@ def bench(atype: torch.dtype,
|
||||
if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
|
||||
continue
|
||||
|
||||
def run(a, _, w_q, w_s, schedule=schedule):
|
||||
ops.machete_gemm(a,
|
||||
w_q,
|
||||
wtype,
|
||||
w_s,
|
||||
b_group_size=group_size,
|
||||
schedule=schedule)
|
||||
|
||||
res = bench_fn(label, sub_label, "machete_best",
|
||||
lambda: loop_over_weights(a, weights_machete, run))
|
||||
res = bench_fns(label, sub_label, "machete_best", [
|
||||
machete_create_bench_fn(
|
||||
bt, out_type=types.output_type, schedule=schedule)
|
||||
for bt in benchmark_tensors
|
||||
])
|
||||
|
||||
results_row = {
|
||||
"M": m,
|
||||
@ -213,25 +411,33 @@ def bench(atype: torch.dtype,
|
||||
|
||||
|
||||
# runner
|
||||
def print_timers(timers: Iterable[TMeasurement]):
|
||||
def print_timers(timers: List[TMeasurement]):
|
||||
compare = TBenchmark.Compare(timers)
|
||||
compare.print()
|
||||
|
||||
|
||||
def run(dtype: torch.dtype, sweep_schedules: bool,
|
||||
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||
def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||
types = TypeConfig(
|
||||
act_type=args.act_type,
|
||||
weight_type=scalar_types.uint4b8 if args.group_zero_type is None \
|
||||
else scalar_types.uint4,
|
||||
output_type=args.out_type,
|
||||
group_scale_type=args.group_scale_type,
|
||||
group_zero_type=args.group_zero_type,
|
||||
channel_scale_type=args.channel_scale_type,
|
||||
token_scale_type=args.token_scale_type,
|
||||
)
|
||||
|
||||
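# Note on the weight type selected above: uint4b8 is a 4-bit unsigned type with an
# implicit bias of 8, so it needs no explicit zero points; when --group-zero-type is
# given, plain uint4 with explicit group zero points is benchmarked instead.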
results = []
|
||||
results: List[TMeasurement] = []
|
||||
for m, k, n in MKNs:
|
||||
timers = bench(dtype,
|
||||
scalar_types.uint4b8,
|
||||
128,
|
||||
timers = bench(types,
|
||||
args.group_size,
|
||||
m,
|
||||
k,
|
||||
n,
|
||||
f"{dtype}-gemm",
|
||||
f"{args.act_type}-gemm",
|
||||
f"MKN=({m}x{k}x{n})",
|
||||
sweep_schedules=sweep_schedules)
|
||||
sweep_schedules=args.sweep_schedules)
|
||||
print_timers(timers)
|
||||
results.extend(timers)
|
||||
|
||||
@ -240,7 +446,7 @@ def run(dtype: torch.dtype, sweep_schedules: bool,
|
||||
|
||||
# output makers
|
||||
def make_output(
|
||||
data: Iterable[TMeasurement],
|
||||
data: List[TMeasurement],
|
||||
MKNs: Iterable[Tuple[int, int, int]],
|
||||
base_description: str,
|
||||
timestamp=None,
|
||||
@ -262,7 +468,6 @@ def run_square_bench(args):
|
||||
dim_sizes = list(
|
||||
range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
||||
|
||||
data = run(args.dtype, args.sweep_schedules, MKNs)
|
||||
|
||||
make_output(data, MKNs, f"square_bench-{args.dtype}")
|
||||
@ -306,33 +511,49 @@ def run_model_bench(args):
|
||||
for k, n in KNs:
|
||||
MKNs.append((m, k, n))
|
||||
|
||||
data = run(args.dtype, args.sweep_schedules, MKNs)
|
||||
data = run(args, MKNs)
|
||||
model_bench_data.append(data)
|
||||
|
||||
type_string = f"{args.act_type}"
|
||||
|
||||
# Print all results
|
||||
for data, model_tp in zip(model_bench_data, models_tps):
|
||||
model, tp_size = model_tp
|
||||
print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
|
||||
print(f"== Results {type_string} {model}-TP{tp_size} ====")
|
||||
print_timers(data)
|
||||
|
||||
timestamp = int(time.time())
|
||||
timestr = time.strftime("%Y%m%d-%H%M%S")
|
||||
|
||||
all_data = []
|
||||
all_results = []
|
||||
for d in model_bench_data:
|
||||
all_data.extend(d)
|
||||
all_results.extend(d)
|
||||
|
||||
# pickle all data
|
||||
with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
|
||||
pkl.dump(all_data, f)
|
||||
with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f:
|
||||
args_dict = vars(args)
|
||||
args_dict.pop("func")
|
||||
pkl.dump({
|
||||
"args": args_dict,
|
||||
"results": all_results,
|
||||
}, f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def to_torch_dtype(dt):
|
||||
if dt == "bfloat16":
|
||||
return torch.bfloat16
|
||||
if dt == "float16":
|
||||
return torch.float16
|
||||
raise ValueError("unsupported dtype")
|
||||
return {
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float16": torch.float16,
|
||||
"int8": torch.int8,
|
||||
"float8_e4m3fn": torch.float8_e4m3fn,
|
||||
"int": torch.int,
|
||||
"float": torch.float,
|
||||
}[dt]
|
||||
|
||||
class ToTorchDtype(argparse.Action):
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
setattr(namespace, self.dest, to_torch_dtype(values))
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="""
|
||||
@ -352,12 +573,42 @@ Benchmark Machete GEMM.
|
||||
""", # noqa: E501
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=to_torch_dtype,
|
||||
"--act-type",
|
||||
action=ToTorchDtype,
|
||||
required=True,
|
||||
help="Available options are ['bfloat16', 'float16']",
|
||||
choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group-scale-type",
|
||||
action=ToTorchDtype,
|
||||
choices=['bfloat16', 'float16'],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group-zero-type",
|
||||
type=to_torch_dtype,
|
||||
choices=['bfloat16', 'float16'],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--channel-scale-type",
|
||||
action=ToTorchDtype,
|
||||
choices=['float'],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token-scale-type",
|
||||
action=ToTorchDtype,
|
||||
choices=['float'],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out-type",
|
||||
action=ToTorchDtype,
|
||||
choices=['bfloat16', 'float16'],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group-size",
|
||||
type=int,
|
||||
help="Available options are ['None', '-1', '128'], default=128",
|
||||
default=128,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sweep-schedules",
|
||||
|
@ -20,10 +20,11 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.filename, 'rb') as f:
|
||||
data: List[TMeasurement] = pickle.load(f)
|
||||
data = pickle.load(f)
|
||||
raw_results: List[TMeasurement] = data["results"]
|
||||
|
||||
results = defaultdict(lambda: list())
|
||||
for v in data:
|
||||
for v in raw_results:
|
||||
result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
|
||||
if result is not None:
|
||||
KN = result.group(1)
|
||||
|
@ -40,4 +40,10 @@ WEIGHT_SHAPES = {
|
||||
([8192, 57344], 1),
|
||||
([28672, 8192], 0),
|
||||
],
|
||||
"meta-llama/Llama-3.1-405b-hf": [
|
||||
([16384, 18432], 1),
|
||||
([16384, 16384], 0),
|
||||
([16384, 106496], 1),
|
||||
([53248, 16384], 0),
|
||||
],
|
||||
}
|
||||
|
@ -20,9 +20,9 @@ CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) {
|
||||
// is the layout f(x) = x
|
||||
template <typename Layout>
|
||||
CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
|
||||
if constexpr (std::is_same_v<Layout, void>)
|
||||
if constexpr (std::is_same_v<Layout, void>) {
|
||||
return true;
|
||||
else {
|
||||
} else {
|
||||
constexpr auto coalesced_layout = coalesce(Layout{});
|
||||
if constexpr (rank(coalesced_layout) == 1 &&
|
||||
stride<0>(coalesced_layout) == 1) {
|
||||
|
@ -52,6 +52,7 @@
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
|
||||
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cutlass::epilogue::threadblock {
|
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp (new file, 317 lines)
@ -0,0 +1,317 @@
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
|
||||
|
||||
/*
|
||||
This file defines custom epilogues for fusing channel scales, token scales,
|
||||
bias, and activation zero-points onto a GEMM operation using the
|
||||
CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs.
|
||||
|
||||
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
|
||||
as well as a static prepare_args function that constructs an
|
||||
EVTCompute::Arguments struct.
|
||||
*/
|
||||
|
||||
namespace vllm::c2x {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
/*
|
||||
* This class provides the common load descriptors for the
|
||||
* ScaledEpilogue[...] classes
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBase {
|
||||
protected:
|
||||
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
|
||||
|
||||
template <typename T>
|
||||
using ColOrScalarLoad =
|
||||
cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrScalarLoad =
|
||||
cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrZeroLoad =
|
||||
cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
// This utility function constructs the arguments for the load descriptors
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
// scalar cases.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(torch::Tensor const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
|
||||
return Arguments{data_ptr, tensor.numel() != 1};
|
||||
} else {
|
||||
// it would technically work but no use case as data_ptr is never nullptr
|
||||
static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
}
|
||||
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
|
||||
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
This epilogue function defines a quantized GEMM operation similar to
|
||||
torch._scaled_mm.
|
||||
|
||||
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
|
||||
per-row. B can be quantized per-tensor or per-column.
|
||||
Any combination of per-tensor and per-row or column is supported.
|
||||
A and B must have symmetric quantization (zero point == 0).
|
||||
|
||||
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
|
||||
scales are applied elementwise with numpy-style broadcasting.
|
||||
|
||||
ScaleA and ScaleB define the epilogue functions that apply the scales for
|
||||
the A and B operands respectively. These scales may be either per-tensor or
|
||||
per row or column.
|
||||
*/
|
||||
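// Shape sketch (informal, added for clarity): with an (m, k) x (k, n) problem,
// a_scales is (m, 1) or a scalar and b_scales is (1, n) or a scalar, so the epilogue
// evaluates D = a_scales * (b_scales * accum) with numpy-style broadcasting, which
// equals the GEMM of the rescaled operands (a_scales * A) @ (b_scales * B).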
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogue
|
||||
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
|
||||
* This bias can also be used in the per-tensor azp case, where the activation
|
||||
* zero point (azp) is used to compute an azp correction term,
|
||||
* which is folded into the bias.
|
||||
*
|
||||
* The bias tensor must be per-output channel.
|
||||
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBias
|
||||
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
protected:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowLoad<ElementD>;
|
||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
|
||||
EVTCompute0, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue directly supports per-tensor azp in int32 form.
|
||||
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
|
||||
* term, which should already be multiplied with the scalar azp.
|
||||
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBiasAzp
|
||||
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
|
||||
|
||||
// This is the full AZP term, azp * J @ B, shape (1,n)
|
||||
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
|
||||
|
||||
// Compute float(accum - azp_adj), both operands are int32_t
|
||||
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
|
||||
EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue supports per-token azp by computing and applying
|
||||
* the correction term using a rank-1 update. If the term were materialized,
|
||||
* it would require O(m*n) space, and this way it only requires O(m+n) space.
|
||||
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
|
||||
* point for each row of A.
|
||||
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
|
||||
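// Derivation sketch (informal, added for clarity): if the quantized activation is
// A_q = A / s_a + azp, then A_q @ B = (A / s_a) @ B + azp * (J @ B), where J is the
// all-ones matrix. The correction is therefore the rank-1 outer product of azp (m,1)
// with azp_adj = J @ B (1,n), i.e. the per-column sums of B, which the EVT below
// computes as azp * azp_adj and subtracts from the accumulator in O(m+n) storage.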
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBiasAzpToken
|
||||
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
|
||||
|
||||
// Per-token azp term, shape (m,1)
|
||||
using Azp = typename SUPER::template ColLoad<int32_t>;
|
||||
|
||||
// This is the AZP adjustment term, J @ B, shape (1,n)
|
||||
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
|
||||
|
||||
// Compute azp * azp_adj
|
||||
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, int32_t, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;
|
||||
|
||||
// Compute float(accum - azp*azp_adj), all operands are int32_t
|
||||
using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAcc =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
|
||||
EVTComputeAcc>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
torch::Tensor const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
|
||||
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace vllm::c2x
|
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp (new file, 315 lines)
@ -0,0 +1,315 @@
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
|
||||
|
||||
/*
|
||||
This file defines custom epilogues for fusing channel scales, token scales,
|
||||
bias, and activation zero-points onto a GEMM operation using the
|
||||
CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later.
|
||||
|
||||
Epilogues must contain a public type named EVTCompute of type Sm90EVT,
|
||||
as well as a static prepare_args function that constructs an
|
||||
EVTCompute::Arguments struct.
|
||||
*/
|
||||
|
||||
namespace vllm::c3x {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
/*
|
||||
* This class provides the common load descriptors for the
|
||||
* ScaledEpilogue[...] classes
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogueBase {
|
||||
protected:
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
|
||||
template <typename T>
|
||||
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
// Don't want to support nullptr by default
|
||||
template <typename T, bool EnableNullPtr = false>
|
||||
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
|
||||
// Don't want to support nullptr by default
|
||||
template <typename T, bool EnableNullPtr = false>
|
||||
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
|
||||
// This utility function constructs the arguments for the load descriptors
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
// scalar cases.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(torch::Tensor const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
|
||||
return Arguments{data_ptr, tensor.numel() != 1};
|
||||
} else {
|
||||
static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
|
||||
!std::is_same_v<Descriptor, RowLoad<T, true>>);
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
}
|
||||
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
|
||||
std::is_same_v<Descriptor, RowLoad<T, true>>);
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
This epilogue function defines a quantized GEMM operation similar to
|
||||
torch._scaled_mm.
|
||||
|
||||
A and B may be both either int8 or fp8_e4m3. A can be
|
||||
quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
|
||||
Any combination of per-tensor and per-row or column is supported.
|
||||
A and B must have symmetric quantization (zero point == 0).
|
||||
|
||||
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
|
||||
scales are applied elementwise with numpy-style broadcasting.
|
||||
|
||||
ScaleA and ScaleB define the epilogue functions that apply the scales for
|
||||
the A and B operands respectively. These scales may be either per-tensor or
|
||||
per row or column.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogue
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
|
||||
* This bias can also be used in the per-tensor azp case, where the activation
|
||||
* zero point (azp) is used to compute an azp correction term,
|
||||
* which is folded into the bias.
|
||||
*
|
||||
* The bias tensor must be per-output channel.
|
||||
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogueBias
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowLoad<ElementD>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue directly supports per-tensor azp in int32 form.
|
||||
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
|
||||
* term, which should already be multiplied with the scalar azp.
|
||||
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogueBiasAzp
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowLoad<ElementD, true>;
|
||||
|
||||
// This is the full AZP term, azp * J @ B, shape (1,n)
|
||||
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
|
||||
|
||||
// Compute float(accum - azp_adj), both operands are int32_t
|
||||
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue supports per-token azp by computing and applying
|
||||
* the correction term using a rank-1 update. If the term were materialized,
|
||||
* it would require O(m*n) space, and this way it only requires O(m+n) space.
|
||||
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
|
||||
* point for each row of A.
|
||||
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogueBiasAzpToken
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowLoad<ElementD, true>;
|
||||
|
||||
// Per-token azp term, shape (m,1)
|
||||
using Azp = typename SUPER::template ColLoad<int32_t>;
|
||||
|
||||
// This is the AZP adjustment term, J @ B, shape (1,n)
|
||||
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
|
||||
|
||||
// Compute azp * azp_adj
|
||||
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, int32_t, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;
|
||||
|
||||
// Compute float(accum - azp*azp_adj), all operands are int32_t
|
||||
using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAcc =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
torch::Tensor const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
|
||||
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace vllm::c3x
|
@ -35,6 +35,35 @@ VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
|
||||
}
|
||||
}
|
||||
|
||||
VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = {
|
||||
**DataTypeSize, # type: ignore
|
||||
**{
|
||||
VLLMDataType.u4b8: 4,
|
||||
VLLMDataType.u8b128: 8,
|
||||
}
|
||||
}
|
||||
|
||||
VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
|
||||
VLLMDataType.u4b8: "vllm::kU4B8",
|
||||
VLLMDataType.u8b128: "vllm::kU8B128",
|
||||
DataType.u4: "vllm::kU4",
|
||||
DataType.u8: "vllm::kU8",
|
||||
DataType.s4: "vllm::kS4",
|
||||
DataType.s8: "vllm::kS8",
|
||||
DataType.f16: "vllm::kFloat16",
|
||||
DataType.bf16: "vllm::kBfloat16",
|
||||
}
|
||||
|
||||
VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
|
||||
DataType.u8: "at::ScalarType::Byte",
|
||||
DataType.s8: "at::ScalarType::Char",
|
||||
DataType.e4m3: "at::ScalarType::Float8_e4m3fn",
|
||||
DataType.s32: "at::ScalarType::Int",
|
||||
DataType.f16: "at::ScalarType::Half",
|
||||
DataType.bf16: "at::ScalarType::BFloat16",
|
||||
DataType.f32: "at::ScalarType::Float",
|
||||
}
|
||||
|
||||
VLLMKernelScheduleTag: Dict[Union[
|
||||
MixedInputKernelScheduleType, KernelScheduleType], str] = {
|
||||
**KernelScheduleTag, # type: ignore
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass_extensions/vllm_custom_types.cuh"
|
||||
#include "cutlass_extensions/cute_utils.cuh"
|
||||
#include "cutlass_extensions/vllm_type_utils.cuh"
|
||||
|
||||
// this file extends:
|
||||
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
|
||||
@ -28,8 +29,19 @@ struct InterleavedNumericArrayConverter {
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
CUTE_INVALID_CONTROL_PATH(
|
||||
"InterleavedNumericArrayConverter not implemented\n");
|
||||
if (cute::elect_one_sync()) {
|
||||
if constexpr (std::is_same_v<IlvBlkLayout, void>) {
|
||||
printf(
|
||||
"Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n",
|
||||
nameof_v<T>, nameof_v<S>, N);
|
||||
} else {
|
||||
printf(
|
||||
"Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not "
|
||||
"implemented\n",
|
||||
nameof_v<T>, nameof_v<S>, N, size(IlvBlkLayout{}));
|
||||
}
|
||||
__brkpt();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
@ -56,11 +68,6 @@ struct InterleavedNumericArrayConverter<
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// TODO (LucasWilkinson): Implement
|
||||
// for Array<cutlass::float8_e4m3fn, N> <= Array<vllm_uint4b8_t, N>
|
||||
|
||||
// ....
|
||||
|
||||
template <typename RegConvert32bit, typename T, typename S, int N>
|
||||
struct ArrayConverterPacked32Bit {
|
||||
using result_type = Array<T, N>;
|
||||
@ -86,14 +93,16 @@ struct ArrayConverterPacked32Bit {
|
||||
using ScalarConverter = NumericConverter<T, S>;
|
||||
|
||||
template <typename PackedSrc>
|
||||
CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) {
|
||||
CUTLASS_DEVICE static auto to_regs(PackedSrc const& src) {
|
||||
if constexpr (sizeof(PackedSrc) == 1) {
|
||||
return static_cast<uint32_t>(reinterpret_cast<const uint8_t&>(source));
|
||||
return Array<uint32_t, 1>{reinterpret_cast<uint8_t const&>(src)};
|
||||
} else if constexpr (sizeof(PackedSrc) == 2) {
|
||||
return static_cast<uint32_t>(reinterpret_cast<const uint16_t&>(source));
|
||||
return Array<uint32_t, 1>{reinterpret_cast<uint16_t const&>(src)};
|
||||
} else if constexpr (sizeof(PackedSrc) == 4) {
|
||||
return Array<uint32_t, 1>{reinterpret_cast<uint32_t const&>(src)};
|
||||
} else {
|
||||
static_assert(sizeof(PackedSrc) == 4);
|
||||
return reinterpret_cast<const uint32_t&>(source);
|
||||
static_assert(sizeof(PackedSrc) == 8);
|
||||
return reinterpret_cast<Array<uint32_t, 2> const&>(src);
|
||||
}
|
||||
}
|
||||
|
||||
@ -110,7 +119,7 @@ struct ArrayConverterPacked32Bit {
|
||||
static_assert(std::is_same_v<typename PackedSrcType::Element, S>);
|
||||
static_assert(std::is_same_v<typename PackedResultType::Element, T>);
|
||||
|
||||
return RegConvert32bit::template convert<PackedResultType>(to_reg(source));
|
||||
return RegConvert32bit::template convert<PackedResultType>(to_regs(source));
|
||||
}
|
||||
|
||||
friend class detail::VectorizedConverter;
|
||||
@ -140,6 +149,131 @@ struct ArrayConverterPacked32Bit {
|
||||
}
|
||||
};
|
||||
|
||||
// Convert 8 4bit values packed into a 32bit register to 8 8bit values packed
// into 2 32bit registers.
template <uint8_t LUT0, uint8_t LUT1, uint8_t LUT2, uint8_t LUT3,  //
uint8_t LUT4, uint8_t LUT5, uint8_t LUT6, uint8_t LUT7,  //
uint8_t LUT8, uint8_t LUT9, uint8_t LUT10, uint8_t LUT11,  //
uint8_t LUT12, uint8_t LUT13, uint8_t LUT14, uint8_t LUT15>
CUTLASS_DEVICE cutlass::AlignedArray<uint32_t, 2> lut_4bit_to_8bit_convert(
uint32_t src) {
cutlass::AlignedArray<uint32_t, 2> r;
// Determines if the value indexes the top half of the LUT (i.e. LUT[8:15])
// if set, or the bottom half (i.e. LUT[0:7]) if not set. It is then moved
// into bit position 0x4 of each nibble so that, when OR'd with
// final_prmt_base, it selects the correct candidate: elements of the final
// prmt index that are >= 0x4 select the high candidate (i.e. LUT[8:15]),
// and elements < 0x4 select the low candidate (i.e. LUT[0:7]).
uint32_t high_bit = (src & 0x88888888) >> 1;

// `high_bit` is OR'd with final_prmt_base (0x32103210) to form the prmt
// index that selects the correct high or low candidate for each byte.
const uint32_t final_prmt_base = 0x32103210;

// Ignore the high bit when indexing into the LUT: for each 4bit value we
// look up both the high and low candidates, then use
// high_bit | final_prmt_base to select the correct one.
uint32_t lut_idx = (src & 0x77777777);
|
||||
|
||||
auto pack = [](uint8_t a, uint8_t b, uint8_t c, uint8_t d) {
|
||||
return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) |
|
||||
(uint32_t(d) << 24);
|
||||
};
|
||||
|
||||
static constexpr uint32_t LOW_0 = pack(LUT0, LUT1, LUT2, LUT3);
|
||||
static constexpr uint32_t LOW_1 = pack(LUT4, LUT5, LUT6, LUT7);
|
||||
static constexpr uint32_t HIGH_0 = pack(LUT8, LUT9, LUT10, LUT11);
|
||||
static constexpr uint32_t HIGH_1 = pack(LUT12, LUT13, LUT14, LUT15);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < 2; ++ii, lut_idx >>= 16, high_bit >>= 16) {
|
||||
uint32_t final_prmt_idx = final_prmt_base | high_bit;
|
||||
|
||||
// This uses a look up table to convert packed int4s to packed int8s,
|
||||
// using the int4 value as the index to prmt. It first select both the
|
||||
// high and low candidates, then uses the high bit (i.e. `high_bit`) to
|
||||
// select the correct candidate.
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .b32 low, high;\n"
|
||||
" prmt.b32 low, %1, %2, %5;\n"
|
||||
" prmt.b32 high, %3, %4, %5;\n"
|
||||
" prmt.b32 %0, low, high, %6;\n"
|
||||
"}\n"
|
||||
: "=r"(r[ii])
|
||||
: "n"(LOW_0), "n"(LOW_1), "n"(HIGH_0), "n"(HIGH_1), "r"(lut_idx),
|
||||
"r"(final_prmt_idx));
|
||||
}
|
||||
|
||||
return r;
|
||||
};
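For reference, the prmt-based lookup above can be emulated on the host. The short Python sketch below is an illustration only; the helper name and byte packing are assumptions, not vLLM code. Each nibble of the 32-bit source indexes a 16-entry LUT, and the 8 resulting bytes are packed into two 32-bit words in the same order the kernel produces them.

# Host-side emulation of the LUT-based 4bit -> 8bit conversion above
# (hypothetical helper for illustration; the kernel does this with prmt).
def lut_4bit_to_8bit(src: int, lut) -> tuple:
    assert len(lut) == 16
    out = [0, 0]
    for i in range(8):                    # 8 nibbles in a 32bit word
        nibble = (src >> (4 * i)) & 0xF   # the nibble value is the LUT index
        out[i // 4] |= (lut[nibble] & 0xFF) << (8 * (i % 4))
    return out[0], out[1]

# The int8 LUT used below maps a biased nibble v to int8(v - 8).
int8_lut = [(v - 8) & 0xFF for v in range(16)]
assert lut_4bit_to_8bit(0x76543210, int8_lut) == (0xFBFAF9F8, 0xFFFEFDFC)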
|
||||
|
||||
// for Array<int8_t, N> <= Array<vllm_uint4b8_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<int8_t, vllm_uint4b8_t, N, Round> {
|
||||
using result_type = Array<int8_t, N>;
|
||||
using source_type = Array<vllm_uint4b8_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
// [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as int8s
|
||||
auto r = lut_4bit_to_8bit_convert<0xF8, 0xF9, 0xFA, 0xFB, //
|
||||
0xFC, 0xFD, 0xFE, 0xFF, //
|
||||
0x00, 0x01, 0x02, 0x03, //
|
||||
0x04, 0x05, 0x06, 0x07>(src_[0]);
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::float_e4m3_t, N> <= Array<vllm_uint4b8_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::float_e4m3_t, vllm_uint4b8_t, N, Round> {
|
||||
using result_type = Array<cutlass::float_e4m3_t, N>;
|
||||
using source_type = Array<vllm_uint4b8_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
// [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as fp8s
|
||||
auto r = lut_4bit_to_8bit_convert<0xD0, 0xCE, 0xCC, 0xCA, //
|
||||
0xC8, 0xC4, 0xC0, 0xB8, //
|
||||
0x00, 0x38, 0x40, 0x44, //
|
||||
0x48, 0x4A, 0x4C, 0x4E>(src_[0]);
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
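The fp8 LUT constants above can be sanity-checked on the host. The Python sketch below (illustration only; e4m3_to_float is a throwaway helper, not vLLM code) decodes each byte with the e4m3 format (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits) and confirms the table encodes the integers -8 through 7.

# Decode e4m3 bytes and check that the LUT above encodes [-8, ..., 7].
def e4m3_to_float(b: int) -> float:
    s, e, m = b >> 7, (b >> 3) & 0xF, b & 0x7
    mag = (m / 8.0) * 2.0**-6 if e == 0 else (1 + m / 8.0) * 2.0**(e - 7)
    return -mag if s else mag

lut = [0xD0, 0xCE, 0xCC, 0xCA, 0xC8, 0xC4, 0xC0, 0xB8,
       0x00, 0x38, 0x40, 0x44, 0x48, 0x4A, 0x4C, 0x4E]
assert [e4m3_to_float(b) for b in lut] == list(range(-8, 8))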
|
||||
|
||||
// for Array<cutlass::half_t, N> <= Array<vllm_uint4b8_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::half_t, vllm_uint4b8_t, N, Round> {
|
||||
@ -148,7 +282,8 @@ struct NumericArrayConverter<cutlass::half_t, vllm_uint4b8_t, N, Round> {
|
||||
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
@ -249,7 +384,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
@ -338,7 +474,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
@ -417,7 +554,8 @@ struct NumericArrayConverter<cutlass::half_t, vllm_uint8b128_t, N, Round> {
|
||||
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
// Hold output FP16s in reg. We need 1 reg for every 2 elements
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
@ -469,7 +607,8 @@ struct NumericArrayConverter<float, vllm_uint8b128_t, N, Round> {
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
PackedResultType r;
|
||||
|
||||
// __byte_perm simulates the add.u32 0x4B000000 to every u8 element of
|
||||
@ -513,7 +652,8 @@ struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint4b8_t, N, Round> {
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) {
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src_reg = src_[0];
|
||||
// Hold output BF16s in reg. We need 1 reg for every 2 elements
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
@ -603,7 +743,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
@ -671,7 +812,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
@ -788,6 +930,61 @@ struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint8b128_t, N, Round> {
|
||||
|
||||
#endif
|
||||
|
||||
// for Array<int8_t, N> <= Array<cutlass::half_t, N>
|
||||
// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<int8_t, cutlass::half_t, N, Round> {
|
||||
using result_type = Array<int8_t, N>;
|
||||
using source_type = Array<cutlass::half_t, N>;
|
||||
|
||||
struct RegConvert {
|
||||
// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
|
||||
template <typename PackedResultType, int src_regs>
|
||||
CUTLASS_DEVICE static PackedResultType convert(
|
||||
Array<uint32_t, src_regs> src) {
|
||||
// Hold output int8s in reg. We need 1 reg for every 4 elements
|
||||
using RegArray = cutlass::AlignedArray<
|
||||
uint32_t, std::max(PackedResultType::kElements / 4, size_t(1))>;
|
||||
RegArray r;
|
||||
|
||||
static constexpr uint32_t MAGIC_BIAS_ = 0x64806480;
|
||||
auto MAGIC_BIAS = *reinterpret_cast<const half2*>(&MAGIC_BIAS_);
|
||||
|
||||
*reinterpret_cast<half2*>(&src[0]) =
|
||||
__hadd2(*reinterpret_cast<half2*>(&src[0]), MAGIC_BIAS);
|
||||
|
||||
if constexpr (src_regs > 1) {
|
||||
*reinterpret_cast<half2*>(&src[1]) =
|
||||
__hadd2(*reinterpret_cast<half2*>(&src[1]), MAGIC_BIAS);
|
||||
}
|
||||
|
||||
static_assert(PackedResultType::kElements <= 4);
|
||||
uint32_t uint8s;
|
||||
static constexpr uint32_t MASK_0246 = 0x6420;
|
||||
static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080;
|
||||
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
|
||||
: "=r"(uint8s)
|
||||
: "r"(src[0]), "r"((src_regs > 1) ? src[1] : src[0]),
|
||||
"n"(MASK_0246));
|
||||
|
||||
uint32_t int8s = (uint8s ^ UINT8s_TO_INT8s_MASK);
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(int8s);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
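The magic-bias trick above can be checked with a few lines of NumPy. This is a host-side illustration only, assuming the fp16 inputs are already integers in [-128, 127] after rounding: adding 1152 (0x6480 as fp16) moves each value into [1024, 1280), where the low byte of the fp16 bit pattern equals value + 128, and XOR-ing with 0x80 recovers the signed int8.

import numpy as np

# Host-side check of the FP16 -> INT8 magic-bias conversion used above.
x = np.arange(-128, 128, dtype=np.float16)
biased = (x + np.float16(1152)).view(np.uint16)  # 1152 == 0x6480 as fp16
low_byte = (biased & 0xFF).astype(np.uint8)      # equals x + 128
int8s = (low_byte ^ 0x80).view(np.int8)          # flip the bias bit
assert np.array_equal(int8s, x.astype(np.int8))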
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
42 csrc/cutlass_extensions/vllm_type_utils.cuh Normal file
@ -0,0 +1,42 @@
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"
#include "cuda_bf16.h"

#include "cutlass_extensions/vllm_custom_types.cuh"

namespace cutlass {

template <typename T>
struct nameof {
static constexpr char const* value = "unknown";
};

template <typename T>
inline constexpr auto nameof_v = nameof<T>::value;

#define NAMEOF_TYPE(T)                     \
template <>                                \
struct nameof<T> {                         \
static constexpr char const* value = #T;   \
};

NAMEOF_TYPE(float_e4m3_t)
NAMEOF_TYPE(float_e5m2_t)
NAMEOF_TYPE(half_t)
NAMEOF_TYPE(nv_bfloat16)
NAMEOF_TYPE(bfloat16_t)
NAMEOF_TYPE(float)

NAMEOF_TYPE(int4b_t)
NAMEOF_TYPE(int8_t)
NAMEOF_TYPE(int32_t)
NAMEOF_TYPE(int64_t)

NAMEOF_TYPE(vllm_uint4b8_t)
NAMEOF_TYPE(uint4b_t)
NAMEOF_TYPE(uint8_t)
NAMEOF_TYPE(vllm_uint8b128_t)
NAMEOF_TYPE(uint32_t)
NAMEOF_TYPE(uint64_t)

};  // namespace cutlass

@ -8,6 +8,10 @@
|
||||
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
|
||||
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
|
||||
|
||||
using namespace vllm;
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
|
||||
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
|
||||
@ -22,12 +26,11 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
@ -42,10 +45,10 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBias>(
|
||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogue>(
|
||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
@ -61,10 +64,10 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
|
||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzp>(
|
||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
@ -78,12 +81,11 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
@ -98,10 +100,10 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBias>(
|
||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogue>(
|
||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
@ -117,10 +119,10 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
|
||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzp>(
|
||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
@ -134,13 +136,12 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
assert(out.dtype() == torch::kFloat16);
|
||||
return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t,
|
||||
Epilogue>(
|
||||
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
} else {
|
||||
@ -148,13 +149,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return vllm::cutlass_gemm_sm89_fp8_dispatch<
|
||||
cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>(
|
||||
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return vllm::cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
@ -170,10 +171,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBias>(
|
||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogue>(
|
||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
@ -189,10 +190,10 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
|
||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzp>(
|
||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
|
@ -21,7 +21,6 @@
|
||||
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
|
||||
|
||||
#include "broadcast_load_epilogue_c2x.hpp"
|
||||
#include "common.hpp"
|
||||
// clang-format on
|
||||
|
||||
@ -71,307 +70,6 @@ struct enable_sm89_to_sm90 : Kernel {
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This class provides the common load descriptors for the
|
||||
* ScaledEpilogue[...] classes
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBase {
|
||||
protected:
|
||||
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
|
||||
|
||||
template <typename T>
|
||||
using ColOrScalarLoad =
|
||||
cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrScalarLoad =
|
||||
cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrZeroLoad =
|
||||
cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
// This utility function constructs the arguments for the load descriptors
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
// scalar cases.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(torch::Tensor const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
|
||||
return Arguments{data_ptr, tensor.numel() != 1};
|
||||
} else {
|
||||
// it would technically work but no use case as data_ptr is never nullptr
|
||||
static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
}
|
||||
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
|
||||
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
This epilogue function defines a quantized GEMM operation similar to
|
||||
torch._scaled_mm.
|
||||
|
||||
A and B may both be either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
|
||||
A and B must have symmetric quantization (zero point == 0).
|
||||
|
||||
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
|
||||
scales are applied elementwise with numpy-style broadcasting.
|
||||
|
||||
ScaleA and ScaleB define the epilogue functions that apply the scales for
|
||||
the A and B operands respectively. These scales may be either per-tensor or
|
||||
per row or column.
|
||||
*/
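As an out-of-kernel reference for the operation this epilogue fuses, the NumPy sketch below (illustration only, not vLLM code; scaled_mm_ref is a made-up name) computes D = (a_scales * A) @ (b_scales * B) by accumulating in int32 and applying the broadcast scales to the accumulator, which is mathematically equivalent.

import numpy as np

# Unfused reference: D = a_scales * b_scales * (A @ B), with a_scales of
# shape (m, 1) or (1, 1) and b_scales of shape (1, n) or (1, 1).
def scaled_mm_ref(A, B, a_scales, b_scales, out_dtype=np.float16):
    acc = A.astype(np.int32) @ B.astype(np.int32)      # int32 accumulator
    d = a_scales * b_scales * acc.astype(np.float32)   # broadcast scales
    return d.astype(out_dtype)

rng = np.random.default_rng(0)
A = rng.integers(-128, 128, (4, 16), dtype=np.int8)
B = rng.integers(-128, 128, (16, 8), dtype=np.int8)
D = scaled_mm_ref(A, B, rng.random((4, 1), dtype=np.float32),
                  rng.random((1, 8), dtype=np.float32))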
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogue
|
||||
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
|
||||
* This bias can also be used in the per-tensor azp case, where the activation
|
||||
* zero point (azp) is used to compute an azp correction term,
|
||||
* which is folded into the bias.
|
||||
*
|
||||
* The bias tensor must be per-output channel.
|
||||
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBias
|
||||
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
protected:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowLoad<ElementD>;
|
||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
|
||||
EVTCompute0, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue directly supports per-tensor azp in int32 form.
|
||||
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
|
||||
* term, which should already be multiplied with the scalar azp.
|
||||
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
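A NumPy sketch of what this epilogue computes (illustration only, not vLLM code; scaled_mm_azp_ref is a hypothetical name): with asymmetric activation quantization A is approximately a_scale * (A_q - azp), so A @ B expands to a_scale * b_scale * (A_q @ B_q - azp * (J @ B_q)); the caller precomputes azp_adj = azp * B_q.sum(axis=0) and the epilogue subtracts it from the int32 accumulator before scaling and adding the per-channel bias.

import numpy as np

# Reference for the per-tensor-azp epilogue:
# D = a_scales * b_scales * (A_q @ B_q - azp_adj) + bias,
# where azp_adj = azp * B_q.sum(axis=0) is precomputed by the caller.
def scaled_mm_azp_ref(A_q, B_q, a_scales, b_scales, azp, bias):
    acc = A_q.astype(np.int32) @ B_q.astype(np.int32)       # (m, n) int32
    azp_adj = azp * B_q.astype(np.int32).sum(axis=0)        # (n,) == azp * J @ B
    d = a_scales * b_scales * (acc - azp_adj).astype(np.float32)
    return d + bias                                          # bias: per-channel (n,)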
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBiasAzp
|
||||
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
|
||||
|
||||
// This is the full AZP term, azp * J @ B, shape (1,n)
|
||||
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
|
||||
|
||||
// Compute float(accum - azp_adj), both operands are int32_t
|
||||
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
|
||||
EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue supports per-token azp by computing and applying
|
||||
* the correction term using a rank-1 update. If the term were materialized,
|
||||
* it would require O(m*n) space, and this way it only requires O(m+n) space.
|
||||
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
|
||||
* point for each row of A.
|
||||
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBiasAzpToken
|
||||
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
|
||||
|
||||
// Per-token azp term, shape (m,1)
|
||||
using Azp = typename SUPER::template ColLoad<int32_t>;
|
||||
|
||||
// This is the AZP adjustment term, J @ B, shape (1,n)
|
||||
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
|
||||
|
||||
// Compute azp * azp_adj
|
||||
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, int32_t, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;
|
||||
|
||||
// Compute float(accum - azp*azp_adj), all operands are int32_t
|
||||
using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAcc =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
|
||||
EVTComputeAcc>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
torch::Tensor const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
|
||||
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Arch, template <typename> typename ArchGuard,
|
||||
typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename> typename Epilogue_, typename TileShape,
|
||||
|
@ -23,11 +23,12 @@
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "broadcast_load_epilogue_c3x.hpp"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
#include "common.hpp"
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
using namespace vllm;
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
@ -56,305 +57,6 @@ struct enable_sm90_or_later : Kernel {
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This class provides the common load descriptors for the
|
||||
* ScaledEpilogue[...] classes
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogueBase {
|
||||
protected:
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
|
||||
template <typename T>
|
||||
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
// Don't want to support nullptr by default
|
||||
template <typename T, bool EnableNullPtr = false>
|
||||
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
|
||||
// Don't want to support nullptr by default
|
||||
template <typename T, bool EnableNullPtr = false>
|
||||
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
|
||||
// This utility function constructs the arguments for the load descriptors
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
// scalar cases.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(torch::Tensor const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
|
||||
return Arguments{data_ptr, tensor.numel() != 1};
|
||||
} else {
|
||||
static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
|
||||
!std::is_same_v<Descriptor, RowLoad<T, true>>);
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
}
|
||||
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
|
||||
std::is_same_v<Descriptor, RowLoad<T, true>>);
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.

A and B may both be either int8 or fp8_e4m3. A can be
quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
|
||||
Any combination of per-tensor and per-row or column is supported.
|
||||
A and B must have symmetric quantization (zero point == 0).
|
||||
|
||||
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
|
||||
scales are applied elementwise with numpy-style broadcasting.
|
||||
|
||||
ScaleA and ScaleB define the epilogue functions that apply the scales for
|
||||
the A and B operands respectively. These scales may be either per-tensor or
|
||||
per row or column.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogue
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
|
||||
* This bias can also be used in the per-tensor azp case, where the activation
|
||||
* zero point (azp) is used to compute an azp correction term,
|
||||
* which is folded into the bias.
|
||||
*
|
||||
* The bias tensor must be per-output channel.
|
||||
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogueBias
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowLoad<ElementD>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue directly supports per-tensor azp in int32 form.
|
||||
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
|
||||
* term, which should already be multiplied with the scalar azp.
|
||||
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogueBiasAzp
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowLoad<ElementD, true>;
|
||||
|
||||
// This is the full AZP term, azp * J @ B, shape (1,n)
|
||||
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
|
||||
|
||||
// Compute float(accum - azp_adj), both operands are int32_t
|
||||
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue supports per-token azp by computing and applying
|
||||
* the correction term using a rank-1 update. If the term were materialized,
|
||||
* it would require O(m*n) space, and this way it only requires O(m+n) space.
|
||||
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
|
||||
* point for each row of A.
|
||||
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
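For the per-token case, a NumPy sketch (illustration only, not vLLM code; scaled_mm_azp_token_ref is a hypothetical name): the full correction matrix is the outer product of the per-row zero points azp (shape (m, 1)) and the precomputed column sums azp_adj = B_q.sum(axis=0) (shape (1, n)), which is why only O(m + n) extra storage is needed.

import numpy as np

# Reference for the per-token-azp epilogue. azp (m, 1) * azp_adj (1, n) is
# the rank-1 correction; the kernel applies it per tile instead of
# materializing the (m, n) matrix up front.
def scaled_mm_azp_token_ref(A_q, B_q, a_scales, b_scales, azp, bias):
    acc = A_q.astype(np.int32) @ B_q.astype(np.int32)           # (m, n)
    azp_adj = B_q.astype(np.int32).sum(axis=0, keepdims=True)   # (1, n) == J @ B
    d = a_scales * b_scales * (acc - azp * azp_adj).astype(np.float32)
    return d + bias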
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogueBiasAzpToken
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowLoad<ElementD, true>;
|
||||
|
||||
// Per-token azp term, shape (m,1)
|
||||
using Azp = typename SUPER::template ColLoad<int32_t>;
|
||||
|
||||
// This is the AZP adjustment term, J @ B, shape (1,n)
|
||||
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
|
||||
|
||||
// Compute azp * azp_adj
|
||||
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, int32_t, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;
|
||||
|
||||
// Compute float(accum - azp*azp_adj), all operands are int32_t
|
||||
using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAcc =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
torch::Tensor const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
|
||||
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
@ -721,11 +423,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == c.dtype(),
|
||||
"currently bias dtype must match output dtype ", c.dtype());
|
||||
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBias>(
|
||||
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
|
||||
c, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogue>(c, a, b, a_scales,
|
||||
b_scales);
|
||||
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogue>(
|
||||
c, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
@ -740,10 +442,10 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>(
|
||||
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzpToken>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzp>(
|
||||
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
|
@ -3,8 +3,10 @@ import math
|
||||
import os
|
||||
import shutil
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, fields
|
||||
from functools import reduce
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import jinja2
|
||||
# yapf conflicts with isort for this block
|
||||
@ -14,7 +16,10 @@ from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag,
|
||||
MixedInputKernelScheduleType,
|
||||
TileSchedulerTag,
|
||||
TileSchedulerType, VLLMDataType,
|
||||
VLLMDataTypeNames, VLLMDataTypeTag,
|
||||
VLLMDataTypeNames,
|
||||
VLLMDataTypeSize, VLLMDataTypeTag,
|
||||
VLLMDataTypeTorchDataTypeTag,
|
||||
VLLMDataTypeVLLMScalarTypeTag,
|
||||
VLLMKernelScheduleTag)
|
||||
|
||||
# yapf: enable
|
||||
@ -27,49 +32,125 @@ DISPATCH_TEMPLATE = """
|
||||
#include "../machete_mm_launcher.cuh"
|
||||
|
||||
namespace machete {
|
||||
using GemmDispatcher_ = GemmDispatcher<
|
||||
{{DataTypeTag[type_config.element_a]}}, // ElementA
|
||||
{{DataTypeTag[type_config.element_b]}}, // ElementB
|
||||
{{DataTypeTag[type_config.element_d]}}, // ElementD
|
||||
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
|
||||
{{DataTypeTag[type_config.element_b_scale]}}, // Scales
|
||||
{{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints
|
||||
|
||||
{% for s in schedules %}extern torch::Tensor
|
||||
impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PyTorchArguments args);
|
||||
{% endfor %}
|
||||
template <>
|
||||
torch::Tensor GemmDispatcher_::dispatch(PyTorchArguments args) {
|
||||
{% for impl_config in impl_configs %}
|
||||
{% set type_sig = gen_type_sig(impl_config.types) -%}
|
||||
{% for s in impl_config.schedules %}
|
||||
extern torch::Tensor impl_{{type_sig}}_sch_{{gen_sch_sig(s)}}(MMArgs);
|
||||
{%- endfor %}
|
||||
|
||||
torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) {
|
||||
[[maybe_unused]] auto M = args.A.size(0);
|
||||
[[maybe_unused]] auto N = args.B.size(1);
|
||||
[[maybe_unused]] auto K = args.A.size(1);
|
||||
|
||||
if (!args.schedule) {
|
||||
{%- for cond, s in heuristic %}
|
||||
if (!args.maybe_schedule) {
|
||||
{%- for cond, s in impl_config.heuristic %}
|
||||
{%if cond is not none%}if ({{cond}})
|
||||
{%- else %}else
|
||||
{%- endif %}
|
||||
return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);{% endfor %}
|
||||
return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);{% endfor %}
|
||||
}
|
||||
|
||||
{% for s in schedules %}
|
||||
if (*args.schedule == "{{ gen_sch_name(s) }}") {
|
||||
return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);
|
||||
}
|
||||
{% endfor %}
|
||||
{%- for s in impl_config.schedules %}
|
||||
if (*args.maybe_schedule == "{{ gen_sch_sig(s) }}")
|
||||
return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);
|
||||
{%- endfor %}
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for "
|
||||
"schedule = ", *args.schedule);
|
||||
"schedule = ", *args.maybe_schedule);
|
||||
}
|
||||
{%- endfor %}
|
||||
|
||||
|
||||
static inline std::optional<at::ScalarType> maybe_scalartype(
|
||||
c10::optional<at::Tensor> const& t) {
|
||||
if (!t) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
return t->scalar_type();
|
||||
};
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<std::string> GemmDispatcher_::supported_schedules() {
|
||||
return {
|
||||
{% for s in schedules -%}
|
||||
"{{ gen_sch_name(s) }}"{{ ",
|
||||
" if not loop.last }}{%- endfor %}
|
||||
};
|
||||
torch::Tensor mm_dispatch(MMArgs args) {
|
||||
auto out_type = args.maybe_out_type.value_or(args.A.scalar_type());
|
||||
auto a_type = args.A.scalar_type();
|
||||
auto maybe_g_scales_type = maybe_scalartype(args.maybe_group_scales);
|
||||
auto maybe_g_zeros_type = maybe_scalartype(args.maybe_group_zeros);
|
||||
auto maybe_ch_scales_type = maybe_scalartype(args.maybe_channel_scales);
|
||||
auto maybe_tok_scales_type = maybe_scalartype(args.maybe_token_scales);
|
||||
|
||||
{% for impl_config in impl_configs %}
|
||||
{% set t = impl_config.types -%}
|
||||
{% set type_sig = gen_type_sig(t) -%}
|
||||
if (args.b_type == {{VLLMScalarTypeTag[t.b]}}
|
||||
&& a_type == {{TorchTypeTag[t.a]}}
|
||||
&& out_type == {{TorchTypeTag[t.out]}}
|
||||
&& {%if t.b_group_scale != void -%}
|
||||
maybe_g_scales_type == {{TorchTypeTag[t.b_group_scale]}}
|
||||
{%- else %}!maybe_g_scales_type{%endif%}
|
||||
&& {%if t.b_group_zeropoint != void -%}
|
||||
maybe_g_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}}
|
||||
{%- else %}!maybe_g_zeros_type{%endif%}
|
||||
&& {%if t.b_channel_scale != void -%}
|
||||
maybe_ch_scales_type == {{TorchTypeTag[t.b_channel_scale]}}
|
||||
{%- else %}!maybe_ch_scales_type{%endif%}
|
||||
&& {%if t.a_token_scale != void -%}
|
||||
maybe_tok_scales_type == {{TorchTypeTag[t.a_token_scale]}}
|
||||
{%- else %}!maybe_tok_scales_type{%endif%}
|
||||
) {
|
||||
return mm_dispatch_{{type_sig}}(args);
|
||||
}
|
||||
{%- endfor %}
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "machete_mm(..) is not implemented for "
|
||||
"a_type=", args.A.scalar_type(),
|
||||
", b_type=", args.b_type.str(),
|
||||
", out_type=", out_type,
|
||||
", with_group_scale_type=", maybe_g_scales_type
|
||||
? toString(*maybe_g_scales_type) : "None",
|
||||
", with_group_zeropoint_type=", maybe_g_zeros_type
|
||||
? toString(*maybe_g_zeros_type) : "None",
|
||||
", with_channel_scale_type=", maybe_ch_scales_type
|
||||
? toString(*maybe_ch_scales_type) : "None",
|
||||
", with_token_scale_type=", maybe_tok_scales_type
|
||||
? toString(*maybe_tok_scales_type) : "None",
|
||||
"; implemented types are: \\n",
|
||||
{%- for impl_config in impl_configs %}
|
||||
{% set t = impl_config.types -%}
|
||||
"\\t{{gen_type_option_name(t)}}\\n",
|
||||
{%- endfor %}
|
||||
"");
|
||||
}
|
||||
|
||||
std::vector<std::string> supported_schedules_dispatch(
|
||||
SupportedSchedulesArgs args) {
|
||||
auto out_type = args.maybe_out_type.value_or(args.a_type);
|
||||
|
||||
{% for impl_config in impl_configs %}
|
||||
{% set t = impl_config.types -%}
|
||||
{% set schs = impl_config.schedules -%}
|
||||
if (args.b_type == {{VLLMScalarTypeTag[t.b]}}
|
||||
&& args.a_type == {{TorchTypeTag[t.a]}}
|
||||
&& out_type == {{TorchTypeTag[t.out]}}
|
||||
&& {%if t.b_group_scale != void -%}
|
||||
args.maybe_group_scales_type == {{TorchTypeTag[t.b_group_scale]}}
|
||||
{%- else %}!args.maybe_group_scales_type{%endif%}
|
||||
&& {%if t.b_group_zeropoint != void-%}
|
||||
args.maybe_group_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}}
|
||||
{%- else %}!args.maybe_group_zeros_type{%endif%}
|
||||
) {
|
||||
return {
|
||||
{%- for s in impl_config.schedules %}
|
||||
"{{gen_sch_sig(s)}}"{% if not loop.last %},{% endif %}
|
||||
{%- endfor %}
|
||||
};
|
||||
}
|
||||
{%- endfor %}
|
||||
|
||||
return {};
|
||||
};
|
||||
|
||||
}; // namespace machete
|
||||
"""
|
||||
|
||||
@ -77,20 +158,10 @@ IMPL_TEMPLATE = """
|
||||
#include "../machete_mm_launcher.cuh"
|
||||
|
||||
namespace machete {
|
||||
template <typename Config, bool with_C, bool with_scales, bool with_zeropoints>
|
||||
using Kernel = MacheteKernelTemplate<
|
||||
{{DataTypeTag[type_config.element_a]}}, // ElementA
|
||||
{{DataTypeTag[type_config.element_b]}}, // ElementB
|
||||
{{DataTypeTag[type_config.element_d]}}, // ElementD
|
||||
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
|
||||
{{DataTypeTag[type_config.element_b_scale]}}, // Scales
|
||||
{{DataTypeTag[type_config.element_b_zeropoint]}}, // Zeropoints
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
|
||||
Config, with_C, with_scales, with_zeropoints>;
|
||||
|
||||
{% for sch in schedules %}
|
||||
{% set schedule_name = gen_sch_name(sch) -%}
|
||||
struct sch_{{schedule_name}} {
|
||||
|
||||
{% for sch in unique_schedules(impl_configs) %}
|
||||
{% set sch_sig = gen_sch_sig(sch) -%}
|
||||
struct sch_{{sch_sig}} {
|
||||
using TileShapeNM = Shape<{{
|
||||
to_cute_constant(sch.tile_shape_mn)|join(', ')}}>;
|
||||
using ClusterShape = Shape<{{
|
||||
@ -101,27 +172,34 @@ struct sch_{{schedule_name}} {
|
||||
using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}};
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
};
|
||||
|
||||
torch::Tensor
|
||||
impl_{{type_name}}_sch_{{schedule_name}}(PyTorchArguments args) {
|
||||
bool with_C = args.C.has_value(), with_scales = args.scales.has_value(),
|
||||
with_zeropoints = args.zeros.has_value();
|
||||
|
||||
{% for s in specializations %}
|
||||
if (with_C == {{s.with_C|lower}}
|
||||
&& with_zeropoints == {{s.with_zeropoints|lower}}
|
||||
&& with_scales == {{s.with_scales|lower}}) {
|
||||
return run_impl<Kernel<sch_{{schedule_name}}, {{s.with_C|lower}},
|
||||
{{s.with_scales|lower}}, {{s.with_zeropoints|lower}}>>(args);
|
||||
}{% endfor %}
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "for the sake of compile times and binary size machete_mm(..) is "
|
||||
" not implemented for with_C=", with_C, ", with_scales=", with_scales,
|
||||
", with_zeropoints=", with_zeropoints,
|
||||
" (for {{type_name}}_sch_{{schedule_name}})");
|
||||
}
|
||||
{% endfor %}
|
||||
|
||||
{% for impl_config in impl_configs %}
|
||||
{% set t = impl_config.types -%}
|
||||
{% set schs = impl_config.schedules -%}
|
||||
{% set type_sig = gen_type_sig(t) -%}
|
||||
|
||||
template<typename Sch>
|
||||
using Kernel_{{type_sig}} = MacheteKernelTemplate<
|
||||
{{DataTypeTag[t.a]}}, // ElementA
|
||||
{{DataTypeTag[t.b]}}, // ElementB
|
||||
{{DataTypeTag[t.out]}}, // ElementD
|
||||
{{DataTypeTag[t.accumulator]}}, // Accumulator
|
||||
{{DataTypeTag[t.b_group_scale]}}, // GroupScaleT
|
||||
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
|
||||
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
|
||||
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
|
||||
Sch>;
|
||||
|
||||
{% for sch in schs %}
|
||||
{% set sch_sig = gen_sch_sig(sch) -%}
|
||||
torch::Tensor
|
||||
impl_{{type_sig}}_sch_{{sch_sig}}(MMArgs args) {
|
||||
return run_impl<Kernel_{{type_sig}}<sch_{{sch_sig}}>>(args);
|
||||
}
|
||||
{%- endfor %}
|
||||
{%- endfor %}
|
||||
|
||||
}; // namespace machete
|
||||
"""
|
||||
@ -130,26 +208,34 @@ PREPACK_TEMPLATE = """
|
||||
#include "../machete_prepack_launcher.cuh"
|
||||
|
||||
namespace machete {
|
||||
using PrepackBDispatcher_ = PrepackBDispatcher<
|
||||
{{DataTypeTag[type_config.element_a]}}, // ElementA
|
||||
{{DataTypeTag[type_config.element_b]}}, // ElementB
|
||||
{{DataTypeTag[type_config.element_d]}}, // ElementD
|
||||
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
|
||||
{{DataTypeTag[type_config.element_b_scale]}}, // Scales
|
||||
{{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints
|
||||
|
||||
using PrepackedLayoutB = PrepackedLayoutBTemplate<
|
||||
{{DataTypeTag[type_config.element_a]}}, // ElementA
|
||||
{{DataTypeTag[type_config.element_b]}}, // ElementB
|
||||
{{DataTypeTag[type_config.element_d]}}, // ElementD
|
||||
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>;
|
||||
|
||||
template <>
|
||||
torch::Tensor PrepackBDispatcher_::dispatch(torch::Tensor B) {
|
||||
return prepack_impl<PrepackedLayoutB>(B);
|
||||
torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
|
||||
auto convert_type = args.maybe_group_scales_type.value_or(args.a_type);
|
||||
{%- for t in types %}
|
||||
{% set b_type = unsigned_type_with_bitwidth(t.b_num_bits) %}
|
||||
if (args.a_type == {{TorchTypeTag[t.a]}}
|
||||
&& args.b_type.size_bits() == {{t.b_num_bits}}
|
||||
&& convert_type == {{TorchTypeTag[t.convert]}}) {
|
||||
return prepack_impl<
|
||||
PrepackedLayoutBTemplate<
|
||||
{{DataTypeTag[t.a]}}, // ElementA
|
||||
{{DataTypeTag[b_type]}}, // ElementB
|
||||
{{DataTypeTag[t.convert]}}, // ElementConvert
|
||||
{{DataTypeTag[t.accumulator]}}, // Accumulator
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>
|
||||
>(args.B);
|
||||
}
|
||||
{%- endfor %}
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"prepack_B_dispatch(..) is not implemented for "
|
||||
"atype = ", args.a_type,
|
||||
", b_type = ", args.b_type.str(),
|
||||
", with_group_scales_type= ", args.maybe_group_scales_type ?
|
||||
toString(*args.maybe_group_scales_type) : "None");
|
||||
}
|
||||
|
||||
}; // namespace machete
|
||||
"""
|
||||
|
||||
@ -166,32 +252,34 @@ class ScheduleConfig:
|
||||
tile_scheduler: TileSchedulerType
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class TypeConfig:
|
||||
element_a: DataType
|
||||
element_b: Union[DataType, VLLMDataType]
|
||||
element_b_scale: DataType
|
||||
element_b_zeropoint: DataType
|
||||
element_d: DataType
|
||||
a: DataType
|
||||
b: Union[DataType, VLLMDataType]
|
||||
b_group_scale: DataType
|
||||
b_group_zeropoint: DataType
|
||||
b_channel_scale: DataType
|
||||
a_token_scale: DataType
|
||||
out: DataType
|
||||
accumulator: DataType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PrepackTypeConfig:
|
||||
a: DataType
|
||||
b_num_bits: int
|
||||
convert: DataType
|
||||
accumulator: DataType
|
||||
|
||||
|
||||
@dataclass
|
||||
class Specialization:
|
||||
with_C: bool
|
||||
with_zeropoints: bool
|
||||
with_scales: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImplConfig:
|
||||
type_config: TypeConfig
|
||||
schedule_configs: List[ScheduleConfig]
|
||||
specializations: List[Specialization]
|
||||
types: TypeConfig
|
||||
schedules: List[ScheduleConfig]
|
||||
heuristic: List[Tuple[Optional[str], ScheduleConfig]]
|
||||
|
||||
|
||||
def generate_schedule_name(schedule_config: ScheduleConfig) -> str:
|
||||
def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
|
||||
tile_shape = (
|
||||
f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
|
||||
)
|
||||
@ -209,40 +297,34 @@ def generate_schedule_name(schedule_config: ScheduleConfig) -> str:
|
||||
f"_{epilogue_schedule}_{tile_scheduler}")
|
||||
|
||||
|
||||
# mostly unique shorter schedule_name
|
||||
def generate_terse_schedule_name(schedule_config: ScheduleConfig) -> str:
|
||||
# mostly unique shorter sch_sig
|
||||
def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
|
||||
kernel_terse_names_replace = {
|
||||
"KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_",
|
||||
"TmaWarpSpecializedCooperative_": "TmaCoop_",
|
||||
"StreamKScheduler": "streamK",
|
||||
}
|
||||
|
||||
schedule_name = generate_schedule_name(schedule_config)
|
||||
sch_sig = generate_sch_sig(schedule_config)
|
||||
for orig, terse in kernel_terse_names_replace.items():
|
||||
schedule_name = schedule_name.replace(orig, terse)
|
||||
return schedule_name
|
||||
sch_sig = sch_sig.replace(orig, terse)
|
||||
return sch_sig
|
||||
|
||||
|
||||
# unique type_name
|
||||
def generate_type_signature(kernel_type_config: TypeConfig):
|
||||
element_a = VLLMDataTypeNames[kernel_type_config.element_a]
|
||||
element_b = VLLMDataTypeNames[kernel_type_config.element_b]
|
||||
element_d = VLLMDataTypeNames[kernel_type_config.element_d]
|
||||
accumulator = VLLMDataTypeNames[kernel_type_config.accumulator]
|
||||
element_scale = VLLMDataTypeNames[kernel_type_config.element_b_scale]
|
||||
element_zeropoint = VLLMDataTypeNames[
|
||||
kernel_type_config.element_b_zeropoint]
|
||||
|
||||
return (f"{element_a}{element_b}{element_d}"
|
||||
f"{accumulator}{element_scale}{element_zeropoint}")
|
||||
def generate_type_signature(kernel_types: TypeConfig):
    return str("".join([
        VLLMDataTypeNames[getattr(kernel_types, field.name)]
        for field in fields(TypeConfig)
    ]))
|
||||
|
||||
|
||||
# non-unique shorter type_name
|
||||
def generate_terse_type_signature(kernel_type_config: TypeConfig):
|
||||
element_a = VLLMDataTypeNames[kernel_type_config.element_a]
|
||||
element_b = VLLMDataTypeNames[kernel_type_config.element_b]
|
||||
|
||||
return f"{element_a}{element_b}"
|
||||
def generate_type_option_name(kernel_types: TypeConfig):
    return ", ".join([
        f"{field.name.replace('b_', 'with_')+'_type'}=" +
        VLLMDataTypeNames[getattr(kernel_types, field.name)]
        for field in fields(TypeConfig)
    ])
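# Illustrative only (not part of the generator): assuming the usual short
# names in VLLMDataTypeNames (e.g. "f16", "u4b8", "void", "f32"), a GPTQ-style
# TypeConfig would render roughly as
#   "a_type=f16, b_type=u4b8, with_group_scale_type=f16,
#    with_group_zeropoint_type=void, with_channel_scale_type=void,
#    a_token_scale_type=void, out_type=f16, accumulator_type=f32"
# which is the form listed in the generated TORCH_CHECK error message under
# "implemented types".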
|
||||
|
||||
|
||||
def is_power_of_two(n):
|
||||
@ -263,13 +345,36 @@ def to_cute_constant(value: List[int]):
|
||||
return _to_cute_constant(value)
|
||||
|
||||
|
||||
def unique_schedules(impl_configs: List[ImplConfig]):
    return list(
        set(sch for impl_config in impl_configs
            for sch in impl_config.schedules))
|
||||
|
||||
|
||||
def unsigned_type_with_bitwidth(num_bits):
    return {
        4: DataType.u4,
        8: DataType.u8,
        16: DataType.u16,
        32: DataType.u32,
        64: DataType.u64,
    }[num_bits]
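# Minimal usage sketch (illustrative, not part of the generator): the prepack
# dispatch only cares about the packed storage width of B, so e.g.
#   unsigned_type_with_bitwidth(4)  # -> DataType.u4
#   unsigned_type_with_bitwidth(8)  # -> DataType.u8
# which is how different 4-bit weight types (e.g. u4 and u4b8) can share the
# same prepack instantiation when their activation/convert types match.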
|
||||
|
||||
|
||||
template_globals = {
    "void": DataType.void,
    "DataTypeTag": VLLMDataTypeTag,
    "VLLMScalarTypeTag": VLLMDataTypeVLLMScalarTypeTag,
    "TorchTypeTag": VLLMDataTypeTorchDataTypeTag,
    "KernelScheduleTag": VLLMKernelScheduleTag,
    "EpilogueScheduleTag": EpilogueScheduleTag,
    "TileSchedulerTag": TileSchedulerTag,
    "to_cute_constant": to_cute_constant,
    "gen_sch_name": generate_terse_schedule_name,
    "gen_sch_sig": generate_terse_sch_sig,
    "gen_type_sig": generate_type_signature,
    "unique_schedules": unique_schedules,
    "unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
    "gen_type_option_name": generate_type_option_name
}
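# These globals are what the Jinja templates above call into; for example the
# dispatch template emits "{{ gen_sch_sig(s) }}" for schedule names and
# "{{ TorchTypeTag[t.a] }}" for torch dtype tags. A hedged sketch of how a
# template could pick them up (names/API here are illustrative, not the exact
# helper used in this file):
#
#   env = jinja2.Environment()
#   env.globals.update(template_globals)
#   env.from_string(DISPATCH_TEMPLATE).render(impl_configs=impl_configs)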
|
||||
|
||||
|
||||
@ -284,42 +389,82 @@ mm_impl_template = create_template(IMPL_TEMPLATE)
|
||||
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
|
||||
|
||||
|
||||
def create_sources(impl_config: ImplConfig, num_impl_files=1):
|
||||
def create_sources(impl_configs: List[ImplConfig], num_impl_files=8):
|
||||
sources = []
|
||||
|
||||
type_name = generate_type_signature(impl_config.type_config)
|
||||
terse_type_name = generate_terse_type_signature(impl_config.type_config)
|
||||
|
||||
sources.append((
|
||||
f"machete_mm_{terse_type_name}",
|
||||
mm_dispatch_template.render(type_name=type_name,
|
||||
type_config=impl_config.type_config,
|
||||
schedules=impl_config.schedule_configs,
|
||||
heuristic=impl_config.heuristic),
|
||||
"machete_mm_dispatch",
|
||||
mm_dispatch_template.render(impl_configs=impl_configs),
|
||||
))
|
||||
|
||||
prepack_types = []
|
||||
for impl_config in impl_configs:
|
||||
convert_type = impl_config.types.a \
|
||||
if impl_config.types.b_group_scale == DataType.void \
|
||||
else impl_config.types.b_group_scale
|
||||
prepack_types.append(
|
||||
PrepackTypeConfig(
|
||||
a=impl_config.types.a,
|
||||
b_num_bits=VLLMDataTypeSize[impl_config.types.b],
|
||||
convert=convert_type,
|
||||
accumulator=impl_config.types.accumulator,
|
||||
))
|
||||
|
||||
    def prepacked_type_key(prepack_type: PrepackTypeConfig):
        # For now we can just use the first accumulator type seen, since
        # the tensor core shapes/layouts don't vary based on accumulator
        # type, so we can generate less code this way.
        return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert)
|
||||
|
||||
unique_prepack_types = []
|
||||
prepack_types_seen = set()
|
||||
for prepack_type in prepack_types:
|
||||
key = prepacked_type_key(prepack_type)
|
||||
if key not in prepack_types_seen:
|
||||
unique_prepack_types.append(prepack_type)
|
||||
prepack_types_seen.add(key)
|
||||
|
||||
sources.append((
|
||||
f"machete_prepack_{terse_type_name}",
|
||||
prepack_dispatch_template.render(
|
||||
type_name=type_name,
|
||||
type_config=impl_config.type_config,
|
||||
),
|
||||
"machete_prepack",
|
||||
prepack_dispatch_template.render(types=unique_prepack_types, ),
|
||||
))
|
||||
|
||||
num_schedules = len(impl_config.schedule_configs)
|
||||
schedules_per_file = math.ceil(num_schedules / num_impl_files)
|
||||
for part, i in enumerate(range(0, num_schedules, schedules_per_file)):
|
||||
file_schedules = impl_config.schedule_configs[i:i + schedules_per_file]
|
||||
    # Split up impls across files
    num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
    num_impls_per_file = math.ceil(num_impls / num_impl_files)

    files_impls: List[List[ImplConfig]] = [[]]

    curr_num_impls_assigned = 0
    curr_impl_in_file = 0
    curr_impl_configs = deepcopy(list(reversed(impl_configs)))

    while curr_num_impls_assigned < num_impls:
        room_left_in_file = num_impls_per_file - curr_impl_in_file
        if room_left_in_file == 0:
            files_impls.append([])
            room_left_in_file = num_impls_per_file
            curr_impl_in_file = 0

        curr_ic = curr_impl_configs[-1]
        if len(curr_ic.schedules) >= room_left_in_file:
            # Break apart the current impl config
            tmp_ic = deepcopy(curr_ic)
            tmp_ic.schedules = curr_ic.schedules[:room_left_in_file]
            curr_ic.schedules = curr_ic.schedules[room_left_in_file:]
            files_impls[-1].append(tmp_ic)
        else:
            files_impls[-1].append(curr_ic)
            curr_impl_configs.pop()
        curr_num_impls_assigned += len(files_impls[-1][-1].schedules)
        curr_impl_in_file += len(files_impls[-1][-1].schedules)
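    # Worked example (illustrative only, not executed here): with three impl
    # configs of 4 schedules each and num_impl_files=2, we get num_impls=12
    # and num_impls_per_file=6. The loop above then packs
    #   file 0: all 4 schedules of config 1 + the first 2 of config 2
    #   file 1: the remaining 2 of config 2 + all 4 of config 3
    # so every generated *_impl_part file compiles a roughly equal number of
    # kernel instantiations.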
|
||||
|
||||
for part, file_impls in enumerate(files_impls):
|
||||
sources.append((
|
||||
f"machete_mm_{terse_type_name}_impl_part{part}",
|
||||
mm_impl_template.render(
|
||||
type_name=type_name,
|
||||
type_config=impl_config.type_config,
|
||||
schedules=file_schedules,
|
||||
specializations=impl_config.specializations,
|
||||
),
|
||||
f"machete_mm_impl_part{part+1}",
|
||||
mm_impl_template.render(impl_configs=file_impls),
|
||||
))
|
||||
|
||||
return sources
|
||||
|
||||
|
||||
@ -328,187 +473,169 @@ def generate():
|
||||
# about how this works
|
||||
SCRIPT_DIR = os.path.dirname(__file__)
|
||||
|
||||
schedule_common_params = dict(
|
||||
sch_common_params = dict(
|
||||
kernel_schedule=TmaMI,
|
||||
epilogue_schedule=TmaCoop,
|
||||
tile_scheduler=TileSchedulerType.StreamK,
|
||||
)
|
||||
|
||||
    # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
    default_tile_heuristic_config = {
        #### M = 257+
        "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
        "M > 256": ((128, 256), (2, 1, 1)),
        #### M = 129-256
        "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
        "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
        "M > 128": ((128, 256), (2, 1, 1)),
        #### M = 65-128
        "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
        "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
        "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
        "M > 64": ((128, 128), (2, 1, 1)),
        #### M = 33-64
        "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
        "M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
        "M > 32": ((128, 64), (2, 1, 1)),
        #### M = 17-32
        "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
        "M > 16": ((256, 32), (2, 1, 1)),
        #### M = 1-16
        "N >= 26624": ((256, 16), (1, 1, 1)),
        None: ((128, 16), (1, 1, 1)),
    }
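    # Note: ordering matters here. The generated dispatch turns these entries
    # into an if / else-if chain in the order given (see the
    # "{%if cond is not none%}" block in the dispatch template above), so more
    # specific conditions must come before the catch-alls, and the final
    # `None` entry becomes the trailing `else`.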
|
||||
|
||||
# For now we use the same heuristic for all types
|
||||
# Heuristic is currently tuned for H100s
|
||||
default_heuristic = [
|
||||
#### M = 257+
|
||||
(
|
||||
"M > 256 && K <= 16384 && N <= 4096",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 128),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 256",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 256),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
#### M = 129-256
|
||||
(
|
||||
"M > 128 && K <= 4096 && N <= 4096",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 64),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 128 && K <= 8192 && N <= 8192",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 128),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 128",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 256),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
#### M = 65-128
|
||||
(
|
||||
"M > 64 && K <= 4069 && N <= 4069",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 32),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 64 && K <= 4069 && N <= 8192",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 64),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 64 && K >= 8192 && N >= 12288",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(256, 128),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 64",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 128),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
#### M = 33-64
|
||||
(
|
||||
"M > 32 && K <= 6144 && N <= 6144",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 16),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 32 && K >= 16384 && N >= 12288",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(256, 64),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 32",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 64),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
#### M = 17-32
|
||||
(
|
||||
"M > 16 && K <= 12288 && N <= 8192",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 32),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 16",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(256, 32),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
#### M = 1-16
|
||||
(
|
||||
"N >= 26624",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(256, 16),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
None,
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 16),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(cond, ScheduleConfig(*tile_config,
|
||||
**sch_common_params)) # type: ignore
|
||||
for cond, tile_config in default_tile_heuristic_config.items()
|
||||
]
|
||||
|
||||
# Do not use schedules = list(set(...)) because we need to make sure
|
||||
# the output list is deterministic; otherwise the generated kernel file
|
||||
# will be non-deterministic and causes ccache miss.
|
||||
schedules = []
|
||||
for _, schedule_config in default_heuristic:
|
||||
if schedule_config not in schedules:
|
||||
schedules.append(schedule_config)
|
||||
    def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]):
        # Do not use schedules = list(set(...)) because we need to make sure
        # the output list is deterministic; otherwise the generated kernel
        # file will be non-deterministic and cause ccache misses.
        schedules = []
        for _, schedule_config in heuristic:
            if schedule_config not in schedules:
                schedules.append(schedule_config)
        return schedules
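    # Hypothetical usage sketch (mirrors the calls further down):
    #   schedules = get_unique_schedules(default_heuristic)
    # Several heuristic entries map to the same tile/cluster shape, so this
    # typically collapses the entries above into a shorter, first-seen-ordered
    # list of ScheduleConfigs to instantiate.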
|
||||
|
||||
impl_configs = []
|
||||
|
||||
GPTQ_kernel_type_configs = list(
|
||||
TypeConfig(
|
||||
element_a=element_a,
|
||||
element_b=element_b,
|
||||
element_b_scale=element_a,
|
||||
element_b_zeropoint=element_a,
|
||||
element_d=element_a,
|
||||
a=a,
|
||||
b=b,
|
||||
b_group_scale=a,
|
||||
b_group_zeropoint=DataType.void,
|
||||
b_channel_scale=DataType.void,
|
||||
a_token_scale=DataType.void,
|
||||
out=a,
|
||||
accumulator=DataType.f32,
|
||||
) for element_b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
|
||||
for element_a in (DataType.f16, DataType.bf16))
|
||||
|
||||
GPTQ_kernel_specializations = [
|
||||
Specialization(with_C=False, with_zeropoints=False, with_scales=True)
|
||||
]
|
||||
) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
|
||||
for a in (DataType.f16, DataType.bf16))
|
||||
|
||||
impl_configs += [
|
||||
ImplConfig(x[0], x[1], x[2], x[3])
|
||||
for x in zip(GPTQ_kernel_type_configs, itertools.repeat(schedules),
|
||||
itertools.repeat(GPTQ_kernel_specializations),
|
||||
ImplConfig(x[0], x[1], x[2])
|
||||
for x in zip(GPTQ_kernel_type_configs,
|
||||
itertools.repeat(get_unique_schedules(default_heuristic)),
|
||||
itertools.repeat(default_heuristic))
|
||||
]
|
||||
|
||||
AWQ_kernel_type_configs = list(
|
||||
TypeConfig(
|
||||
element_a=element_a,
|
||||
element_b=element_b,
|
||||
element_b_scale=element_a,
|
||||
element_b_zeropoint=element_a,
|
||||
element_d=element_a,
|
||||
a=a,
|
||||
b=b,
|
||||
b_group_scale=a,
|
||||
b_group_zeropoint=a,
|
||||
b_channel_scale=DataType.void,
|
||||
a_token_scale=DataType.void,
|
||||
out=a,
|
||||
accumulator=DataType.f32,
|
||||
) for element_b in (DataType.u4, DataType.u8)
|
||||
for element_a in (DataType.f16, DataType.bf16))
|
||||
) for b in (DataType.u4, DataType.u8)
|
||||
for a in (DataType.f16, DataType.bf16))
|
||||
|
||||
AWQ_kernel_specializations = [
|
||||
Specialization(with_C=False, with_zeropoints=True, with_scales=True)
|
||||
impl_configs += [
|
||||
ImplConfig(x[0], x[1], x[2])
|
||||
for x in zip(AWQ_kernel_type_configs,
|
||||
itertools.repeat(get_unique_schedules(default_heuristic)),
|
||||
itertools.repeat(default_heuristic))
|
||||
]
|
||||
|
||||
# Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
|
||||
# TODO (LucasWilkinson): Further tuning required
|
||||
qqq_tile_heuristic_config = {
|
||||
#### M = 257+
|
||||
# ((128, 256), (2, 1, 1)) Broken for QQQ types
|
||||
# TODO (LucasWilkinson): Investigate further
|
||||
# "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
|
||||
# "M > 256": ((128, 256), (2, 1, 1)),
|
||||
"M > 256": ((128, 128), (2, 1, 1)),
|
||||
#### M = 129-256
|
||||
"M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
|
||||
"M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
|
||||
# ((128, 256), (2, 1, 1)) Broken for QQQ types
|
||||
# TODO (LucasWilkinson): Investigate further
|
||||
# "M > 128": ((128, 256), (2, 1, 1)),
|
||||
"M > 128": ((128, 128), (2, 1, 1)),
|
||||
#### M = 65-128
|
||||
"M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
|
||||
"M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
|
||||
"M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
|
||||
"M > 64": ((128, 128), (2, 1, 1)),
|
||||
#### M = 33-64
|
||||
"M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
|
||||
# Broken for QQQ types
|
||||
# TODO (LucasWilkinson): Investigate further
|
||||
#"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
|
||||
"M > 32": ((128, 64), (2, 1, 1)),
|
||||
#### M = 17-32
|
||||
"M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
|
||||
"M > 16": ((256, 32), (2, 1, 1)),
|
||||
#### M = 1-16
|
||||
"N >= 26624": ((256, 16), (1, 1, 1)),
|
||||
None: ((128, 16), (1, 1, 1)),
|
||||
}
|
||||
|
||||
# For now we use the same heuristic for all types
|
||||
# Heuristic is currently tuned for H100s
|
||||
qqq_heuristic = [
|
||||
(cond, ScheduleConfig(*tile_config,
|
||||
**sch_common_params)) # type: ignore
|
||||
for cond, tile_config in qqq_tile_heuristic_config.items()
|
||||
]
|
||||
|
||||
QQQ_kernel_types = [
|
||||
*(TypeConfig(
|
||||
a=DataType.s8,
|
||||
b=VLLMDataType.u4b8,
|
||||
b_group_scale=b_group_scale,
|
||||
b_group_zeropoint=DataType.void,
|
||||
b_channel_scale=DataType.f32,
|
||||
a_token_scale=DataType.f32,
|
||||
out=DataType.f16,
|
||||
accumulator=DataType.s32,
|
||||
) for b_group_scale in (DataType.f16, DataType.void)),
|
||||
*(TypeConfig(
|
||||
a=DataType.e4m3,
|
||||
b=VLLMDataType.u4b8,
|
||||
b_group_scale=b_group_scale,
|
||||
b_group_zeropoint=DataType.void,
|
||||
b_channel_scale=DataType.f32,
|
||||
a_token_scale=DataType.f32,
|
||||
out=DataType.f16,
|
||||
accumulator=DataType.f32,
|
||||
) for b_group_scale in (DataType.f16, DataType.void)),
|
||||
]
|
||||
|
||||
impl_configs += [
|
||||
ImplConfig(x[0], x[1], x[2], x[3])
|
||||
for x in zip(AWQ_kernel_type_configs, itertools.repeat(schedules),
|
||||
itertools.repeat(AWQ_kernel_specializations),
|
||||
itertools.repeat(default_heuristic))
|
||||
ImplConfig(x[0], x[1], x[2])
|
||||
for x in zip(QQQ_kernel_types,
|
||||
itertools.repeat(get_unique_schedules(qqq_heuristic)),
|
||||
itertools.repeat(qqq_heuristic))
|
||||
]
|
||||
|
||||
output_dir = os.path.join(SCRIPT_DIR, "generated")
|
||||
@ -521,12 +648,11 @@ def generate():
|
||||
os.makedirs(output_dir)
|
||||
|
||||
# Render each group of configurations into separate files
|
||||
for impl_config in impl_configs:
|
||||
for filename, code in create_sources(impl_config):
|
||||
filepath = os.path.join(output_dir, f"{filename}.cu")
|
||||
with open(filepath, "w") as output_file:
|
||||
output_file.write(code)
|
||||
print(f"Rendered template to {filepath}")
|
||||
for filename, code in create_sources(impl_configs):
|
||||
filepath = os.path.join(output_dir, f"{filename}.cu")
|
||||
with open(filepath, "w") as output_file:
|
||||
output_file.write(code)
|
||||
print(f"Rendered template to {filepath}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -171,6 +171,10 @@ struct MacheteCollectiveMma {
|
||||
make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}),
|
||||
Int<DispatchPolicy::Stages>{})));
|
||||
|
||||
using SmemLayoutACopy = decltype(GmemLayoutA::TVbNbKL_to_offset_copy(
|
||||
make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}),
|
||||
Int<DispatchPolicy::Stages>{})));
|
||||
|
||||
using SmemLayoutAtomARowMajor =
|
||||
decltype(rs_smem_selector<GmmaMajorA, ElementA,
|
||||
decltype(cute::get<0>(TileShape_MNK{})),
|
||||
@ -288,14 +292,7 @@ struct MacheteCollectiveMma {
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0,
|
||||
"SmemLayoutAtomScale must evenly divide tile k shape.");
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size.
|
||||
using SmemLayoutACopy = decltype(tile_to_shape(
|
||||
SmemLayoutAtomARowMajor{},
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}),
|
||||
Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
|
||||
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size
|
||||
using SmemLayoutB = decltype(tile_to_shape(
|
||||
SmemLayoutAtomB{},
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
|
||||
@ -428,12 +425,12 @@ struct MacheteCollectiveMma {
|
||||
// clang-format on
|
||||
|
||||
// ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx)
|
||||
using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset(
|
||||
using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset_copy(
|
||||
make_shape(int32_t(0), int32_t(0), int32_t(0)))));
|
||||
|
||||
using ATensor = decltype(make_tensor(
|
||||
get_logical_ptr(static_cast<InternalElementA const*>(nullptr)),
|
||||
shape(GmemLayoutA::TVbNbKL_to_offset(
|
||||
shape(GmemLayoutA::TVbNbKL_to_offset_copy(
|
||||
make_shape(int32_t(0), int32_t(0), int32_t(0)))),
|
||||
PrepackedStrideA{}));
|
||||
|
||||
@ -450,8 +447,8 @@ struct MacheteCollectiveMma {
|
||||
|
||||
static constexpr auto make_tma_copy_A(ATensor tensor_a = ATensor{}) {
|
||||
return make_tma_copy<TmaElementA>(
|
||||
GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}),
|
||||
shape(SmemLayoutA{}(_, _, cute::Int<0>{})),
|
||||
GmemTiledCopyA{}, tensor_a, SmemLayoutACopy{}(_, _, cute::Int<0>{}),
|
||||
shape(SmemLayoutACopy{}(_, _, cute::Int<0>{})),
|
||||
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
|
||||
}
|
||||
|
||||
@ -584,7 +581,7 @@ struct MacheteCollectiveMma {
|
||||
typename Params::TMA_Scale tma_load_scale;
|
||||
typename Params::TMA_Zero tma_load_zero;
|
||||
|
||||
auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L));
|
||||
auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L));
|
||||
tma_load_a = make_tma_copy_A(
|
||||
make_logical_tensor(ptr_A, shape(layout), stride(layout)));
|
||||
|
||||
@ -722,7 +719,7 @@ struct MacheteCollectiveMma {
|
||||
// (TILE_V,TILE_B,m,k,l)
|
||||
auto make_gA_mkl = [&]() {
|
||||
// ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx)
|
||||
auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L));
|
||||
auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L));
|
||||
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(layout));
|
||||
return local_tile(mA_mkl,
|
||||
make_shape(size<0>(layout), PPBlocksPerTile_MK{}),
|
||||
|
@ -21,6 +21,8 @@
|
||||
|
||||
#include "cutlass_extensions/cute_utils.cuh"
|
||||
#include "cutlass_extensions/vllm_numeric_conversion.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
#include "cutlass_extensions/torch_utils.hpp"
|
||||
#include "machete_collective_builder.cuh"
|
||||
#include "machete_prepacked_layout.cuh"
|
||||
#include "machete_interleaving_utils.cuh"
|
||||
@ -37,27 +39,42 @@ using namespace cute;
|
||||
// W is quantized; in this situation our right-hand operand is quantized, so
// we compute the transpose to move it to the left-hand side.
|
||||
template <typename ElementA_, typename ElementB_, typename ElementD_,
|
||||
typename AccumulatorT, typename ScaleT, typename ZeroT,
|
||||
class KernelSchedule, typename ScheduleConfig, bool with_C,
|
||||
bool with_scales, bool with_zeropoints>
|
||||
typename AccumulatorT, typename GroupScaleT, typename GroupZeroT,
|
||||
typename ChannelScaleT, typename TokenScaleT, class KernelSchedule,
|
||||
typename ScheduleConfig>
|
||||
struct MacheteKernelTemplate {
|
||||
static constexpr bool with_C = false; // not ever used
|
||||
static constexpr bool with_group_scales = !std::is_same_v<GroupScaleT, void>;
|
||||
static constexpr bool with_group_zeropoints =
|
||||
!std::is_same_v<GroupZeroT, void>;
|
||||
static constexpr bool with_channel_scales =
|
||||
!std::is_same_v<ChannelScaleT, void>;
|
||||
static constexpr bool with_token_scales = !std::is_same_v<TokenScaleT, void>;
|
||||
|
||||
using MmaType = ElementA_;
|
||||
using ElementA = ElementA_;
|
||||
using ElementB = ElementB_;
|
||||
using ElementD = ElementD_;
|
||||
using ElementC = cute::conditional_t<with_C, ElementD, void>;
|
||||
using ElementZ = ZeroT;
|
||||
using ElementS = ScaleT;
|
||||
|
||||
using ElementAccumulator =
|
||||
AccumulatorT; // Element type for internal accumulation
|
||||
using ElementAccumulator = AccumulatorT;
|
||||
using ElementCompute = AccumulatorT; // For Epilogue
|
||||
// Use dummy values when we don't have scales or zeropoints
|
||||
using ElementZGroup =
|
||||
cute::conditional_t<with_group_zeropoints, GroupZeroT, MmaType>;
|
||||
using ElementSGroup =
|
||||
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
|
||||
using ElementConvertGroup =
|
||||
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
|
||||
using ElementSChannel =
|
||||
cute::conditional_t<with_channel_scales, ChannelScaleT, AccumulatorT>;
|
||||
using ElementSToken =
|
||||
cute::conditional_t<with_token_scales, TokenScaleT, AccumulatorT>;
|
||||
|
||||
using BTypeTuple = cute::conditional_t<
|
||||
with_scales,
|
||||
cute::conditional_t<with_zeropoints,
|
||||
cute::tuple<ElementB, ElementS, ElementZ>,
|
||||
cute::tuple<ElementB, ElementS>>,
|
||||
with_group_scales,
|
||||
cute::conditional_t<with_group_zeropoints,
|
||||
cute::tuple<ElementB, ElementSGroup, ElementZGroup>,
|
||||
cute::tuple<ElementB, ElementSGroup>>,
|
||||
ElementB>;
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
@ -71,8 +88,8 @@ struct MacheteKernelTemplate {
|
||||
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
|
||||
using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
|
||||
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
|
||||
using StrideS = cutlass::detail::TagToStrideA_t<LayoutScale>;
|
||||
using StrideZ = StrideS;
|
||||
using StrideSGroup = cutlass::detail::TagToStrideA_t<LayoutScale>;
|
||||
using StrideZGroup = StrideSGroup;
|
||||
|
||||
using LayoutA_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
@ -85,8 +102,8 @@ struct MacheteKernelTemplate {
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
using PrepackedLayoutB =
|
||||
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementD_, AccumulatorT,
|
||||
LayoutA_Transpose, KernelSchedule>;
|
||||
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementConvertGroup,
|
||||
AccumulatorT, LayoutA_Transpose, KernelSchedule>;
|
||||
|
||||
static int constexpr TileShapeK =
|
||||
128 * 8 / cutlass::sizeof_bits<MmaType>::value;
|
||||
@ -103,12 +120,42 @@ struct MacheteKernelTemplate {
|
||||
using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
|
||||
using TileScheduler = typename ScheduleConfig::TileScheduler;
|
||||
|
||||
static_assert(
|
||||
(!with_channel_scales && !with_token_scales) ||
|
||||
((with_channel_scales && with_token_scales) &&
|
||||
std::is_same_v<ElementSChannel, ElementSToken>),
|
||||
"Currently token and channel scales (if present) must be the same type");
|
||||
|
||||
using EpilogueDescriptor =
|
||||
cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
|
||||
ElementD, EpilogueSchedule>;
|
||||
|
||||
// Currently only supports float scales
|
||||
using ChTokScalesEpilogue =
|
||||
typename vllm::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
|
||||
EpilogueDescriptor>;
|
||||
static_assert((with_channel_scales || with_token_scales) ||
|
||||
(std::is_same_v<ElementSChannel, float> &&
|
||||
std::is_same_v<ElementSToken, float>),
|
||||
"Currently token and channel scales (if present) must be float "
|
||||
"(and if one is present the other must be too)");
|
||||
|
||||
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
|
||||
cutlass::epilogue::fusion::Sm90AccFetch>;
|
||||
|
||||
using EVTCompute =
|
||||
std::conditional_t<with_channel_scales || with_token_scales,
|
||||
typename ChTokScalesEpilogue::EVTCompute,
|
||||
StoreEpilogueCompute>;
|
||||
|
||||
// EVTCompute
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
|
||||
ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose,
|
||||
AlignmentC, ElementD, LayoutD_Transpose, AlignmentD,
|
||||
EpilogueSchedule>::CollectiveOp;
|
||||
ElementAccumulator, ElementSChannel, ElementC, LayoutC_Transpose,
|
||||
AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, EpilogueSchedule,
|
||||
EVTCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::VLLMCollectiveBuilder<
|
||||
@ -131,26 +178,44 @@ struct MacheteKernelTemplate {
|
||||
using MainloopArguments = typename GemmKernel::MainloopArguments;
|
||||
using EpilogueArguments = typename GemmKernel::EpilogueArguments;
|
||||
|
||||
template <typename ShapeA, typename ShapeC, typename ShapeD, typename ShapeS,
|
||||
typename ShapeZ>
|
||||
static Arguments create_arguments(
|
||||
cudaStream_t stream,
|
||||
ElementA const* A_ptr, // A is an MxK matrix
|
||||
Layout<ShapeA, StrideA> const& layout_A,
|
||||
ElementB const* B_ptr, // B is an KxN prepacked matrix
|
||||
ElementD* D_ptr, // D is an MxN matrix
|
||||
Layout<ShapeD, StrideD> const& layout_D,
|
||||
ElementC const* C_ptr, // C is an MxN matrix
|
||||
std::optional<Layout<ShapeC, StrideC>> const& layout_C,
|
||||
ElementS const* S_ptr, // S is an scale_KxN matrix
|
||||
std::optional<Layout<ShapeS, StrideS>> const& layout_S,
|
||||
ElementZ const* Z_ptr, // Z is an scale_KxN matrix
|
||||
std::optional<Layout<ShapeZ, StrideZ>> const& layout_Z,
|
||||
ElementCompute alpha, ElementCompute beta,
|
||||
std::optional<int> maybe_group_size) {
|
||||
static_assert(!with_zeropoints || with_scales);
|
||||
torch::Tensor const& A, // MxK matrix
|
||||
torch::Tensor const& B, // KxN prepacked matrix
|
||||
torch::Tensor& D, // MxN matrix
|
||||
c10::optional<torch::Tensor> const& maybe_g_scales, // scale_KxN matrix
|
||||
c10::optional<torch::Tensor> const& maybe_g_zeros, // scale_KxN matrix
|
||||
c10::optional<int64_t> maybe_group_size,
|
||||
c10::optional<torch::Tensor> const& maybe_ch_scales, // len N vector
|
||||
c10::optional<torch::Tensor> const& maybe_tok_scales) // len M vector
|
||||
{
|
||||
static_assert(!with_group_zeropoints || with_group_scales);
|
||||
|
||||
int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);
|
||||
int M = A.size(0), N = B.size(1), K = A.size(1);
|
||||
TORCH_CHECK(D.size(0) == M && D.size(1) == N);
|
||||
|
||||
auto layout_A = make_cute_layout<StrideA>(A, "A");
|
||||
auto layout_D = make_cute_layout<StrideD>(D, "D");
|
||||
auto layout_S_group =
|
||||
maybe_make_cute_layout<StrideSGroup>(maybe_g_scales, "group_scales");
|
||||
auto layout_Z_group =
|
||||
maybe_make_cute_layout<StrideZGroup>(maybe_g_zeros, "group_zeros");
|
||||
int64_t numel_S_channel = maybe_ch_scales ? maybe_ch_scales->numel() : 0;
|
||||
int64_t numel_S_token = maybe_tok_scales ? maybe_tok_scales->numel() : 0;
|
||||
|
||||
auto unwrap = [](auto const& t) {
|
||||
return t ? t->const_data_ptr() : nullptr;
|
||||
};
|
||||
auto A_ptr = static_cast<ElementA const*>(A.const_data_ptr());
|
||||
auto B_ptr = static_cast<ElementB const*>(B.const_data_ptr());
|
||||
auto D_ptr = static_cast<ElementD*>(D.mutable_data_ptr());
|
||||
auto S_group_ptr =
|
||||
static_cast<ElementSGroup const*>(unwrap(maybe_g_scales));
|
||||
auto Z_group_ptr = static_cast<ElementZGroup const*>(unwrap(maybe_g_zeros));
|
||||
auto S_channel_ptr =
|
||||
static_cast<ElementSChannel const*>(unwrap(maybe_ch_scales));
|
||||
auto S_token_ptr =
|
||||
static_cast<ElementSToken const*>(unwrap(maybe_tok_scales));
|
||||
|
||||
int const group_size =
|
||||
maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
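  // Note: a group_size of -1 (or an absent group_size) is treated as a single
  // group spanning all of K, i.e. one group scale per output column.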
|
||||
@ -159,26 +224,28 @@ struct MacheteKernelTemplate {
|
||||
TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
|
||||
TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);
|
||||
|
||||
if constexpr (with_C) {
|
||||
TORCH_CHECK(C_ptr && layout_C);
|
||||
if constexpr (with_group_scales) {
|
||||
TORCH_CHECK(S_group_ptr && layout_S_group);
|
||||
TORCH_CHECK((size<0>(*layout_S_group) == scale_k &&
|
||||
size<1>(*layout_S_group) == N));
|
||||
} else {
|
||||
TORCH_CHECK(!C_ptr, "C not supported");
|
||||
TORCH_CHECK(!S_group_ptr, "Scales not supported");
|
||||
}
|
||||
|
||||
if constexpr (with_scales) {
|
||||
TORCH_CHECK(S_ptr && layout_S);
|
||||
TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N));
|
||||
} else {
|
||||
TORCH_CHECK(!S_ptr, "Scales not supported");
|
||||
}
|
||||
|
||||
if constexpr (with_zeropoints) {
|
||||
TORCH_CHECK(Z_ptr && layout_Z);
|
||||
TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N));
|
||||
TORCH_CHECK(layout_S && *layout_Z == *layout_S,
|
||||
if constexpr (with_group_zeropoints) {
|
||||
TORCH_CHECK(Z_group_ptr && layout_Z_group);
|
||||
TORCH_CHECK((size<0>(*layout_Z_group) == scale_k &&
|
||||
size<1>(*layout_Z_group) == N));
|
||||
TORCH_CHECK(layout_S_group && *layout_Z_group == *layout_S_group,
|
||||
"Scales and zeros must have the same layout");
|
||||
} else {
|
||||
TORCH_CHECK(!Z_ptr, "Zeropoints not supported");
|
||||
TORCH_CHECK(!Z_group_ptr, "Zeropoints not supported");
|
||||
}
|
||||
|
||||
if constexpr (with_channel_scales || with_token_scales) {
|
||||
TORCH_CHECK(
|
||||
(maybe_ch_scales->numel() == N || maybe_ch_scales->numel() == 1) &&
|
||||
(maybe_tok_scales->numel() == M || maybe_tok_scales->numel() == 1));
|
||||
}
|
||||
|
||||
// Transpose A and D
|
||||
@ -186,24 +253,33 @@ struct MacheteKernelTemplate {
|
||||
// for B (which is At)
|
||||
auto stride_At = layout_A.stride();
|
||||
auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
|
||||
auto stride_Ct = stride_Dt;
|
||||
if (layout_C) {
|
||||
stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride();
|
||||
}
|
||||
|
||||
MainloopArguments mainloop_arguments{};
|
||||
EpilogueArguments epilogue_arguments{
|
||||
{alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt};
|
||||
// {Accum, C, C_layout, D, D}
|
||||
EpilogueArguments epilogue_arguments{};
|
||||
|
||||
if constexpr (with_scales && with_zeropoints) {
|
||||
auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
|
||||
mainloop_arguments =
|
||||
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
|
||||
S_ptr, stride_S, group_size, Z_ptr};
|
||||
} else if constexpr (with_scales) {
|
||||
auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
|
||||
if constexpr (with_channel_scales || with_token_scales) {
|
||||
epilogue_arguments =
|
||||
EpilogueArguments{ChTokScalesEpilogue::prepare_args(
|
||||
*maybe_ch_scales, *maybe_tok_scales),
|
||||
nullptr,
|
||||
{},
|
||||
D_ptr,
|
||||
stride_Dt};
|
||||
} else {
|
||||
epilogue_arguments = EpilogueArguments{{}, nullptr, {}, D_ptr, stride_Dt};
|
||||
}
|
||||
|
||||
if constexpr (with_group_scales && with_group_zeropoints) {
|
||||
auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
|
||||
mainloop_arguments = MainloopArguments{
|
||||
B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size};
|
||||
B_ptr, _StrideB{}, A_ptr, stride_At,
|
||||
S_group_ptr, stride_S_group, group_size, Z_group_ptr};
|
||||
} else if constexpr (with_group_scales) {
|
||||
auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
|
||||
mainloop_arguments =
|
||||
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
|
||||
S_group_ptr, stride_S_group, group_size};
|
||||
} else {
|
||||
mainloop_arguments =
|
||||
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
|
||||
|
@ -5,73 +5,61 @@
|
||||
|
||||
#include "machete_mm_kernel.cuh"
|
||||
#include "cutlass_extensions/torch_utils.hpp"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
namespace machete {
|
||||
|
||||
struct PyTorchArguments {
|
||||
struct MMArgs {
|
||||
torch::Tensor const& A;
|
||||
torch::Tensor const& B;
|
||||
c10::optional<torch::Tensor> const& scales;
|
||||
c10::optional<torch::Tensor> const& zeros;
|
||||
c10::optional<int64_t> group_size;
|
||||
c10::optional<torch::Tensor> const& C;
|
||||
c10::optional<double> alpha;
|
||||
c10::optional<double> beta;
|
||||
c10::optional<std::string> schedule;
|
||||
vllm::ScalarType const& b_type;
|
||||
c10::optional<at::ScalarType> const& maybe_out_type;
|
||||
c10::optional<torch::Tensor> const& maybe_group_scales;
|
||||
c10::optional<torch::Tensor> const& maybe_group_zeros;
|
||||
c10::optional<int64_t> maybe_group_size;
|
||||
c10::optional<torch::Tensor> const& maybe_channel_scales;
|
||||
c10::optional<torch::Tensor> const& maybe_token_scales;
|
||||
c10::optional<std::string> maybe_schedule;
|
||||
};
|
||||
|
||||
struct SupportedSchedulesArgs {
|
||||
at::ScalarType a_type;
|
||||
vllm::ScalarType b_type;
|
||||
c10::optional<at::ScalarType> maybe_group_scales_type;
|
||||
c10::optional<at::ScalarType> maybe_group_zeros_type;
|
||||
c10::optional<at::ScalarType> maybe_channel_scales_type;
|
||||
c10::optional<at::ScalarType> maybe_token_scales_type;
|
||||
c10::optional<at::ScalarType> maybe_out_type;
|
||||
};
|
||||
|
||||
torch::Tensor mm_dispatch(MMArgs args);
|
||||
|
||||
std::vector<std::string> supported_schedules_dispatch(
|
||||
SupportedSchedulesArgs args);
|
||||
|
||||
template <typename MacheteKernel>
|
||||
torch::Tensor run_impl(PyTorchArguments args) {
|
||||
torch::Tensor run_impl(MMArgs args) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A));
|
||||
|
||||
auto device = args.A.device();
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
|
||||
using EleA = typename MacheteKernel::ElementA;
|
||||
using EleB = typename MacheteKernel::ElementB;
|
||||
using EleC = typename MacheteKernel::ElementC;
|
||||
using EleD = typename MacheteKernel::ElementD;
|
||||
using EleScale = typename MacheteKernel::ElementS;
|
||||
using EleZero = typename MacheteKernel::ElementZ;
|
||||
|
||||
using StrideA = typename MacheteKernel::StrideA;
|
||||
using StrideC = typename MacheteKernel::StrideC;
|
||||
using StrideD = typename MacheteKernel::StrideD;
|
||||
using StrideS = typename MacheteKernel::StrideS;
|
||||
using StrideZ = typename MacheteKernel::StrideZ;
|
||||
|
||||
int M = args.A.size(0);
|
||||
int N = args.B.size(1);
|
||||
int K = args.A.size(1);
|
||||
|
||||
// Allocate output
|
||||
torch::Tensor D =
|
||||
torch::empty({M, N}, torch::TensorOptions()
|
||||
.dtype(equivalent_scalar_type_v<EleD>)
|
||||
.device(device));
|
||||
|
||||
auto const &A = args.A, &B = args.B;
|
||||
auto const &C = args.C, &scales = args.scales, &zeros = args.zeros;
|
||||
|
||||
auto layout_A = make_cute_layout<StrideA>(A, "A");
|
||||
auto layout_D = make_cute_layout<StrideD>(D, "D");
|
||||
auto layout_C = maybe_make_cute_layout<StrideC>(C, "C");
|
||||
auto layout_S = maybe_make_cute_layout<StrideS>(scales, "scales");
|
||||
auto layout_Z = maybe_make_cute_layout<StrideZ>(zeros, "zeros");
|
||||
|
||||
auto A_ptr = static_cast<EleA const*>(A.const_data_ptr());
|
||||
auto B_ptr = static_cast<EleB const*>(B.const_data_ptr());
|
||||
auto D_ptr = static_cast<EleD*>(D.mutable_data_ptr());
|
||||
auto C_ptr = static_cast<EleC const*>(C ? C->const_data_ptr() : nullptr);
|
||||
auto S_ptr =
|
||||
static_cast<EleScale const*>(scales ? scales->const_data_ptr() : nullptr);
|
||||
auto Z_ptr =
|
||||
static_cast<EleZero const*>(zeros ? zeros->const_data_ptr() : nullptr);
|
||||
torch::Tensor D = torch::empty(
|
||||
{M, N},
|
||||
torch::TensorOptions()
|
||||
.dtype(equivalent_scalar_type_v<typename MacheteKernel::ElementD>)
|
||||
.device(device));
|
||||
|
||||
auto arguments = MacheteKernel::create_arguments(
|
||||
stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr,
|
||||
layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0),
|
||||
args.group_size);
|
||||
stream, //
|
||||
args.A, args.B, D, args.maybe_group_scales, args.maybe_group_zeros,
|
||||
args.maybe_group_size, args.maybe_channel_scales,
|
||||
args.maybe_token_scales);
|
||||
TORCH_CHECK(MacheteKernel::can_implement(arguments),
|
||||
"Machete kernel cannot be run with these arguments");
|
||||
|
||||
@ -84,12 +72,4 @@ torch::Tensor run_impl(PyTorchArguments args) {
|
||||
return D;
|
||||
};
|
||||
|
||||
template <typename ElementA, typename ElementB, typename ElementD = ElementA,
|
||||
typename AccumulatorT = float, typename ScaleT = ElementA,
|
||||
typename ZeroT = ElementA>
|
||||
struct GemmDispatcher {
|
||||
static torch::Tensor dispatch(PyTorchArguments args);
|
||||
static std::vector<std::string> supported_schedules();
|
||||
};
|
||||
|
||||
}; // namespace machete
|
@ -6,31 +6,49 @@
|
||||
|
||||
namespace machete {
|
||||
|
||||
template <typename TileShapeNKL, typename ElementB, typename BInTensor,
|
||||
typename BTiledOutTensor>
|
||||
static __global__ void prepack_B_kernel(BInTensor B_in,
|
||||
BTiledOutTensor B_tiled_out) {
|
||||
auto tB_in = local_tile(B_in, TileShapeNKL{},
|
||||
make_coord(blockIdx.x, blockIdx.y, blockIdx.z));
|
||||
auto tB_out = B_tiled_out(make_coord(_, _),
|
||||
make_coord(blockIdx.x, blockIdx.y), blockIdx.z);
|
||||
template <int threads, typename PrepackedLayoutB, typename BInTensor,
|
||||
typename ElementB>
|
||||
static __global__ void prepack_B_kernel(BInTensor B_in, ElementB* B_out_ptr) {
|
||||
auto constexpr block_size =
|
||||
Int<size(typename PrepackedLayoutB::PPBlockShape_NK{})>{};
|
||||
auto constexpr eles_per_thread = Int<block_size / threads>{};
|
||||
static_assert(block_size % threads == 0,
|
||||
"block_size must be divisible by the number of threads");
|
||||
|
||||
auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, ElementB>{},
|
||||
Layout<Shape<_4, _32>, Stride<_32, _1>>{},
|
||||
Layout<Shape<_1, _2>>{});
|
||||
// Which pre-packed block are we responsible for
|
||||
auto blk_coord = make_coord(blockIdx.x, blockIdx.y, blockIdx.z);
|
||||
auto tB_in = local_tile(
|
||||
B_in, append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}),
|
||||
blk_coord);
|
||||
|
||||
auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
|
||||
// Find the start offset in the output for this pre-packed block
|
||||
auto bNbKL_to_offset = PrepackedLayoutB::bNbKL_to_offset(shape(B_in));
|
||||
|
||||
Tensor thr_tile_S = thr_copy.partition_S(tB_in);
|
||||
Tensor thr_tile_D = thr_copy.partition_D(tB_out);
|
||||
// Tensor representing a 1:1 mapping to the output space in 1D
|
||||
auto tB_out_linear =
|
||||
make_tensor(get_logical_ptr(B_out_ptr) + bNbKL_to_offset(blk_coord),
|
||||
make_layout(make_shape(block_size)));
|
||||
// Mapping from output space (1D) to input space
|
||||
auto tB_in_linear = make_tensor(
|
||||
tB_in.data(),
|
||||
tB_in.layout()
|
||||
.compose(right_inverse(PrepackedLayoutB::ppblock_ilvd_NK_to_offset()))
|
||||
.with_shape(make_shape(block_size)));
|
||||
|
||||
// Tile for this specific thread (could have used a TiledCopy but these work
|
||||
// best with 2d layouts, this is a simple 1d layout so local_tile is enough,
|
||||
// we are also not that concerned with performance for this kernel)
|
||||
auto thr_tB_in_linear =
|
||||
local_tile(tB_in_linear, make_shape(eles_per_thread), threadIdx.x);
|
||||
auto thr_tB_out_linear =
|
||||
local_tile(tB_out_linear, make_shape(eles_per_thread), threadIdx.x);
|
||||
|
||||
// Construct a register-backed Tensor with the same shape as each thread's
|
||||
// partition
|
||||
auto fragment = make_tensor<ElementB>(shape(thr_tile_D));
|
||||
auto fragment = make_tensor<ElementB>(shape(thr_tB_in_linear));
|
||||
|
||||
// Copy from GMEM to RMEM and from RMEM to GMEM
|
||||
copy(tiled_copy, thr_tile_S, fragment);
|
||||
copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tile_D);
|
||||
copy(thr_tB_in_linear, fragment);
|
||||
copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tB_out_linear);
|
||||
}
|
||||
|
||||
template <typename PrepackedLayoutB, typename InLayout>
|
||||
@ -44,18 +62,15 @@ static void prepack_B_template(
|
||||
|
||||
TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0);
|
||||
TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0);
|
||||
TORCH_CHECK(size<2>(B_layout) % size<2>(TileShapeNKL{}) == 0);
|
||||
|
||||
auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{});
|
||||
auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{});
|
||||
auto L_tiles = size<2>(B_layout) / size<2>(TileShapeNKL{});
|
||||
auto L_tiles = size<2>(B_layout);
|
||||
|
||||
auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout);
|
||||
auto B_tiled_out =
|
||||
make_tensor(get_logical_ptr(B_out_ptr), ilvd_NKbNbKL_to_offset);
|
||||
|
||||
prepack_B_kernel<TileShapeNKL, typename PrepackedLayoutB::ElementB>
|
||||
<<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_tiled_out);
|
||||
prepack_B_kernel<128, PrepackedLayoutB>
|
||||
<<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_out_ptr);
|
||||
}
|
||||
|
||||
}; // namespace machete
|
@ -2,9 +2,17 @@
|
||||
|
||||
#include "machete_prepack_kernel.cuh"
|
||||
#include "cutlass_extensions/torch_utils.hpp"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
namespace machete {
|
||||
|
||||
struct PrepackBArgs {
|
||||
torch::Tensor const& B;
|
||||
at::ScalarType a_type;
|
||||
vllm::ScalarType b_type;
|
||||
c10::optional<at::ScalarType> maybe_group_scales_type;
|
||||
};
|
||||
|
||||
template <typename PrepackedLayoutB>
|
||||
torch::Tensor prepack_impl(torch::Tensor const B) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(B));
|
||||
@ -61,11 +69,6 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
|
||||
return D;
|
||||
};
|
||||
|
||||
template <typename ElementA, typename ElementB, typename ElementD,
|
||||
typename AccumulatorT = float, typename ScaleT = cutlass::half_t,
|
||||
typename ZeroT = cutlass::half_t>
|
||||
struct PrepackBDispatcher {
|
||||
static torch::Tensor dispatch(torch::Tensor B);
|
||||
};
|
||||
torch::Tensor prepack_B_dispatch(PrepackBArgs args);
|
||||
|
||||
}; // namespace machete
|
@ -41,7 +41,7 @@ struct IlvBlkLayoutAuto {};
|
||||
// The contract here is that the `TiledMma` determined below matches the one
|
||||
// ultimately used in the kernel. (this is also why the other element types are
|
||||
// required along with the kernel schedule)
|
||||
template <typename ElementA_, typename ElementB_, typename ElementD_,
|
||||
template <typename ElementA_, typename ElementB_, typename ElementConvert_,
|
||||
typename AccumulatorT, class LayoutB, class KernelSchedule,
|
||||
typename IlvBlkLayout_ = IlvBlkLayoutAuto>
|
||||
// clang-format on
|
||||
@ -49,20 +49,27 @@ struct PrepackedLayoutBTemplate {
|
||||
using MmaType = ElementA_;
|
||||
using ElementA = ElementA_;
|
||||
using ElementB = ElementB_;
|
||||
using ElementD = ElementD_;
|
||||
using ElementAccumulator =
|
||||
AccumulatorT; // Element type for internal accumulation
|
||||
using ElementAccumulator = AccumulatorT;
|
||||
using ElementMma = MmaType;
|
||||
|
||||
// Only use interleaved layouts for subbyte weights, prmt instructions makes
|
||||
// non-interleaved layouts for 8bit+ weights efficient enough we don't need
|
||||
// iterleaved layouts
|
||||
// Interleave for 4bit types when we are not upconverting to fp8 or int8; in
// those cases we use a prmt-instruction LUT to upconvert, which is more
// efficient when the data is not interleaved. For 8bit+ types, prmt
// instructions make non-interleaved layouts efficient enough that we don't
// need interleaved layouts (and can reuse more of the existing cutlass
// converters).
|
||||
static constexpr bool should_interleave =
|
||||
sizeof_bits_v<ElementB> <= 4 &&
|
||||
!std::is_same_v<ElementConvert_, cutlass::float_e4m3_t> &&
|
||||
!std::is_same_v<ElementConvert_, int8_t>;
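// For example, a W4A16 kernel (fp16/bf16 mma) keeps the interleaved layout,
// while the new W4A8 kernels (int8/fp8 mma) do not, since they upconvert via
// a prmt-based LUT that works best on non-interleaved data.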
|
||||
|
||||
// Only use interleaved layouts for subbyte weights,
|
||||
using IlvdBlkLayout = std::conditional_t<
|
||||
std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
|
||||
std::conditional_t<sizeof_bits_v<ElementB> <= 4,
|
||||
decltype(get_interleaved_blk_layout<
|
||||
ElementB, sizeof_bits_v<ElementA>, 32>()),
|
||||
void>,
|
||||
std::conditional_t<
|
||||
should_interleave,
|
||||
decltype(get_interleaved_blk_layout<
|
||||
ElementB, sizeof_bits_v<ElementConvert_>, 32>()),
|
||||
void>,
|
||||
IlvBlkLayout_>;
|
||||
|
||||
// TODO (LucasWilkinson): compare the performance for other sizes
|
||||
@ -135,7 +142,8 @@ struct PrepackedLayoutBTemplate {
|
||||
// then ((IlvBlk), FrgB) is {A, C, B, D, C, G, D, H}
|
||||
auto frgV = get<1, 0>(layout_no_interleave);
|
||||
auto ilvdBlk = IlvdBlkLayout{};
|
||||
static_assert(size(frgV) % 4 == 0, "FrgV must be divisible by 4");
|
||||
static_assert(size(frgV) % size(ilvdBlk) == 0,
|
||||
"FrgV must be divisible by size(ilvdBlk)");
|
||||
auto ilvd_FrgV = make_layout(
|
||||
make_shape(shape(ilvdBlk), Int<size(frgV) / size(ilvdBlk)>{}),
|
||||
make_stride(stride(ilvdBlk), size(ilvdBlk)));
|
||||
@ -175,6 +183,15 @@ struct PrepackedLayoutBTemplate {
|
||||
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
|
||||
}
|
||||
|
||||
// ((athrid_val), (BlocksN, BlocksK, L)) -> (N, K, L)
|
||||
template <typename Shape_NKL>
|
||||
CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset_copy(
|
||||
Shape_NKL shape_mkl) {
|
||||
auto layout = TVbNbKL_to_offset(shape_mkl);
|
||||
return make_layout(coalesce(get<0>(layout)), get<1>(layout),
|
||||
get<2>(layout));
|
||||
}
|
||||
|
||||
// ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
|
||||
template <typename Shape_NKL>
|
||||
CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset(
|
||||
@ -197,6 +214,19 @@ struct PrepackedLayoutBTemplate {
|
||||
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
|
||||
}
|
||||
|
||||
// (BlocksN, BlocksK, L) -> (storage_idx)
|
||||
template <typename Shape_NKL>
|
||||
CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) {
|
||||
// (BlocksN, BlocksK, L)
|
||||
auto blocks_shape =
|
||||
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
|
||||
[](auto x, auto y) { return x / y; });
|
||||
auto stride = size(PPBlockShape_NK{});
|
||||
|
||||
// (BlocksN, BlocksK, L) -> (storage_idx)
|
||||
return make_layout(blocks_shape, compact_col_major(blocks_shape, stride));
|
||||
}
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
|
||||
template <class Shape_NKL>
|
||||
CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) {
|
||||
|
@ -8,89 +8,61 @@ namespace machete {
|
||||
|
||||
using namespace vllm;
|
||||
|
||||
//
|
||||
// Utils (type dispatching)
|
||||
//
|
||||
|
||||
template <typename Fn>
|
||||
static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
|
||||
if (type == vllm::kU4) {
|
||||
return fn(cutlass::uint4b_t{});
|
||||
} else if (type == vllm::kU8) {
|
||||
return fn(cutlass::uint8_t{});
|
||||
} else if (type == vllm::kU4B8) {
|
||||
return fn(cutlass::vllm_uint4b8_t{});
|
||||
} else if (type == vllm::kU8B128) {
|
||||
return fn(cutlass::vllm_uint8b128_t{});
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported type ", type.str());
|
||||
}
|
||||
std::vector<std::string> supported_schedules(
|
||||
at::ScalarType a_type, int64_t b_type_id,
|
||||
c10::optional<at::ScalarType> maybe_group_scales_type,
|
||||
c10::optional<at::ScalarType> maybe_group_zeros_type,
|
||||
c10::optional<at::ScalarType> maybe_channel_scales_type,
|
||||
c10::optional<at::ScalarType> maybe_token_scales_type,
|
||||
c10::optional<at::ScalarType> maybe_out_type) {
|
||||
ScalarType const b_type = ScalarType::from_id(b_type_id);
|
||||
return supported_schedules_dispatch({
|
||||
.a_type = a_type,
|
||||
.b_type = b_type,
|
||||
.maybe_group_scales_type = maybe_group_scales_type,
|
||||
.maybe_group_zeros_type = maybe_group_zeros_type,
|
||||
.maybe_channel_scales_type = maybe_channel_scales_type,
|
||||
.maybe_token_scales_type = maybe_token_scales_type,
|
||||
.maybe_out_type = maybe_out_type,
|
||||
});
|
||||
}
|
||||
|
||||
#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \
|
||||
AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__)
|
||||
|
||||
#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, \
|
||||
AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__))
|
||||
|
||||
//
|
||||
// Interface
|
||||
//
|
||||
|
||||
std::vector<std::string> supported_schedules(ScalarTypeId const btype_id) {
|
||||
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
|
||||
vllm::ScalarType b_type = ScalarType::from_id(btype_id);
|
||||
return scalar_type_dispatch(b_type, [&](auto BType) {
|
||||
return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
|
||||
});
|
||||
#else
|
||||
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
|
||||
#endif
|
||||
torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
|
||||
int64_t b_type_id,
|
||||
c10::optional<at::ScalarType> const& maybe_out_type,
|
||||
c10::optional<torch::Tensor> const& maybe_group_scales,
|
||||
c10::optional<torch::Tensor> const& maybe_group_zeros,
|
||||
c10::optional<int64_t> maybe_group_size,
|
||||
c10::optional<torch::Tensor> const& maybe_channel_scales,
|
||||
c10::optional<torch::Tensor> const& maybe_token_scales,
|
||||
c10::optional<std::string> maybe_schedule) {
|
||||
ScalarType const b_type = ScalarType::from_id(b_type_id);
|
||||
return mm_dispatch({.A = A,
|
||||
.B = B,
|
||||
.b_type = b_type,
|
||||
.maybe_out_type = maybe_out_type,
|
||||
.maybe_group_scales = maybe_group_scales,
|
||||
.maybe_group_zeros = maybe_group_zeros,
|
||||
.maybe_group_size = maybe_group_size,
|
||||
.maybe_channel_scales = maybe_channel_scales,
|
||||
.maybe_token_scales = maybe_token_scales,
|
||||
.maybe_schedule = maybe_schedule});
|
||||
}
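// Note: this single entry point also covers the new W4A8 path: int8/fp8
// activations with 4-bit weights, where quantization is completed by the
// optional fp32 channel-wise weight scales (length N) and fp32 token-wise
// activation scales (length M), typically with an fp16/bf16 output type.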
|
||||
|
||||
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
|
||||
ScalarTypeId const btype_id,
|
||||
c10::optional<torch::Tensor> const& scales,
|
||||
c10::optional<torch::Tensor> const& zeros,
|
||||
c10::optional<int64_t> group_size,
|
||||
c10::optional<torch::Tensor> const& C,
|
||||
c10::optional<double> alpha, c10::optional<double> beta,
|
||||
c10::optional<std::string> schedule) {
|
||||
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
|
||||
ScalarType const btype = ScalarType::from_id(btype_id);
|
||||
auto args = PyTorchArguments{.A = A,
|
||||
.B = B,
|
||||
.scales = scales,
|
||||
.zeros = zeros,
|
||||
.group_size = group_size,
|
||||
.C = C,
|
||||
.alpha = alpha,
|
||||
.beta = beta,
|
||||
.schedule = schedule};
|
||||
|
||||
return scalar_type_dispatch(btype, [&](auto BType) {
|
||||
return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
|
||||
A.scalar_type(), "machete_gemm", [&] {
|
||||
using ComputeType = equivalent_cutlass_type_t<scalar_t>;
|
||||
return GemmDispatcher<ComputeType, decltype(BType)>::dispatch(args);
|
||||
});
|
||||
});
|
||||
#else
|
||||
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
|
||||
#endif
|
||||
}
|
||||
|
||||
torch::Tensor prepack_B(torch::Tensor const& B, ScalarTypeId const btype_id) {
|
||||
ScalarType const btype = ScalarType::from_id(btype_id);
|
||||
return scalar_type_dispatch(btype, [&](auto BType) {
|
||||
return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
|
||||
});
|
||||
torch::Tensor prepack_B(
|
||||
torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id,
|
||||
c10::optional<at::ScalarType> const& maybe_group_scales_type) {
|
||||
ScalarType const b_type = ScalarType::from_id(b_type_id);
|
||||
return prepack_B_dispatch(
|
||||
{.B = B,
|
||||
.a_type = a_type,
|
||||
.b_type = b_type,
|
||||
.maybe_group_scales_type = maybe_group_scales_type});
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("machete_prepack_B", &prepack_B);
|
||||
m.impl("machete_gemm", &gemm);
|
||||
m.impl("machete_mm", &mm);
|
||||
}
|
||||
|
||||
// use CatchAll since supported_schedules has no tensor arguments
|
||||
|
@ -203,13 +203,36 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
|
||||
ops.def("machete_supported_schedules(int btype) -> str[]");
|
||||
ops.def(
|
||||
"machete_gemm(Tensor A, Tensor B, int btype, "
|
||||
" Tensor? scales, Tensor? zeros, int? group_size, "
|
||||
" Tensor? C, float? alpha, float? beta, str? schedule)"
|
||||
"-> Tensor");
|
||||
ops.def("machete_prepack_B(Tensor B, int btype) -> Tensor");
|
||||
"machete_supported_schedules("
|
||||
" ScalarType a_type,"
|
||||
" int b_type,"
|
||||
" ScalarType? maybe_group_scales_type,"
|
||||
" ScalarType? maybe_group_zeros_type,"
|
||||
" ScalarType? maybe_channel_scales_type,"
|
||||
" ScalarType? maybe_token_scales_type,"
|
||||
" ScalarType? maybe_out_type"
|
||||
") -> str[]");
|
||||
ops.def(
|
||||
"machete_mm("
|
||||
" Tensor A,"
|
||||
" Tensor B,"
|
||||
" int b_type,"
|
||||
" ScalarType? out_type,"
|
||||
" Tensor? group_scales,"
|
||||
" Tensor? group_zeros,"
|
||||
" int? group_size,"
|
||||
" Tensor? channel_scales,"
|
||||
" Tensor? token_scales,"
|
||||
" str? schedule"
|
||||
") -> Tensor");
|
||||
ops.def(
|
||||
"machete_prepack_B("
|
||||
" Tensor B,"
|
||||
" ScalarType a_type,"
|
||||
" int b_type,"
|
||||
" ScalarType? group_scales_type"
|
||||
") -> Tensor");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
||||
|
@ -1,284 +0,0 @@
|
||||
"""Tests for the machete kernel.
|
||||
|
||||
Run `pytest tests/kernels/test_machete_gemm.py`.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
pack_rows, quantize_weights)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
MNK_SHAPES = [
|
||||
(1, 128, 128),
|
||||
(1, 512, 1024),
|
||||
(1, 4096, 4096),
|
||||
(1, 8192, 28672),
|
||||
(13, 8192, 4096),
|
||||
(26, 4096, 8192),
|
||||
(64, 4096, 4096),
|
||||
(64, 8192, 28672),
|
||||
(257, 128, 4096),
|
||||
(257, 4224, 4160),
|
||||
(257, 4096, 4096),
|
||||
(1024, 4096, 8192),
|
||||
(1024, 8192, 4096),
|
||||
]
|
||||
|
||||
ACT_TYPES = [torch.float16, torch.bfloat16]
|
||||
WTYPE_ZEROPOINTS = [
|
||||
# GPTQ style
|
||||
(scalar_types.uint4b8, False),
|
||||
(scalar_types.uint8b128, False),
|
||||
# AWQ style
|
||||
(scalar_types.uint4, True),
|
||||
(scalar_types.uint8, True),
|
||||
]
|
||||
|
||||
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
|
||||
# unit tests to a common utility function. Currently the use of
|
||||
# `is_quant_method_supported` conflates kernels with quantization methods
|
||||
# an assumption which is breaking down as quantizations methods can have
|
||||
# have kernels and some kernels support multiple quantization methods.
|
||||
IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)
|
||||
|
||||
|
||||
def rand_data(shape, dtype=torch.float16):
|
||||
return 10 * (torch.rand(shape, dtype=dtype, device="cuda") - 0.3)
|
||||
|
||||
|
||||
def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
|
||||
return zps if zps is None else -1 * s * (zps.to(s.dtype))
|
||||
|
||||
|
||||
def machete_quantize_and_pack(w: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
group_size: int,
|
||||
zero_points: bool = False):
|
||||
assert wtype.is_integer(), "TODO: support floating point weights"
|
||||
|
||||
w_ref, w_q, w_s, w_zp = quantize_weights(
|
||||
w,
|
||||
wtype,
|
||||
group_size,
|
||||
zero_points=zero_points,
|
||||
# to match how the kernel applies zps
|
||||
ref_zero_points_after_scales=True)
|
||||
|
||||
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
|
||||
w_q = w_q.t().contiguous().t() # convert to col major
|
||||
w_q_machete = ops.machete_prepack_B(w_q, wtype)
|
||||
|
||||
opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype.id))
|
||||
|
||||
return w_ref, w_q_machete, w_s, w_zp
|
||||
|
||||
|
||||
def machete_gemm_test_helper(a: torch.Tensor, b: torch.Tensor,
|
||||
wtype: ScalarType, group_size: int,
|
||||
zero_points: bool):
|
||||
w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
|
||||
b, wtype, group_size, zero_points)
|
||||
|
||||
output_ref = torch.matmul(a, w_ref)
|
||||
|
||||
output = ops.machete_gemm(
|
||||
a=a,
|
||||
b_q=w_q_packed,
|
||||
b_type=wtype,
|
||||
b_scales=w_s,
|
||||
b_zeros=maybe_convert_zeropoints(w_zp, w_s),
|
||||
b_group_size=group_size,
|
||||
)
|
||||
|
||||
# Relax atol as our reduction dim becomes larger (more rounding error)
|
||||
# Relax atol when we have zeropoints since the way machete applies
|
||||
# zeropoints (after scales) causes noise around 0
|
||||
atol = 1 if zero_points else min(5e-2 * math.sqrt(a.shape[1]), 1)
|
||||
torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
|
||||
reason="Machete is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("shape",
|
||||
MNK_SHAPES,
|
||||
ids=lambda x: "x".join(str(v) for v in x))
|
||||
@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x))
|
||||
@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS)
|
||||
@pytest.mark.parametrize("group_size", [128, None])
|
||||
def test_machete_all_schedules(shape, atype: torch.dtype,
|
||||
wtype_zeropoints: Tuple[ScalarType, bool],
|
||||
group_size: Optional[int]):
|
||||
m, n, k = shape
|
||||
wtype, zero_points = wtype_zeropoints
|
||||
|
||||
if group_size is not None and k % group_size != 0:
|
||||
return
|
||||
|
||||
print(f"MNK = {m} {n} {k}")
|
||||
|
||||
# Normalize group_size
|
||||
if group_size is None:
|
||||
group_size = k
|
||||
assert group_size <= k
|
||||
|
||||
a = rand_data((m, k), atype)
|
||||
w = rand_data((k, n), atype)
|
||||
|
||||
w_ref, w_q_machete, w_s, w_zp = machete_quantize_and_pack(
|
||||
w, wtype, group_size, zero_points)
|
||||
|
||||
output_ref = torch.matmul(a, w_ref)
|
||||
|
||||
for schedule in ops.machete_supported_schedules(wtype):
|
||||
print(f"Testing schedule {schedule}")
|
||||
output = ops.machete_gemm(
|
||||
a,
|
||||
b_q=w_q_machete,
|
||||
b_type=wtype,
|
||||
b_scales=w_s,
|
||||
b_zeros=maybe_convert_zeropoints(w_zp, w_s),
|
||||
b_group_size=group_size,
|
||||
schedule=schedule,
|
||||
)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.machete_gemm,
|
||||
(a, w_q_machete, wtype.id, w_s, maybe_convert_zeropoints(
|
||||
w_zp, w_s), group_size, None, None, None, schedule))
|
||||
|
||||
# Relax atol as our reduction dim becomes larger (more rounding error)
|
||||
# Relax atol when we have zeropoints since the way machete applies
|
||||
# zeropoints (after scales) causes noise around 0
|
||||
atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1)
|
||||
torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol),\
|
||||
f"Schedule failed {schedule}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
|
||||
reason="Machete is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("shape",
|
||||
MNK_SHAPES,
|
||||
ids=lambda x: "x".join(str(v) for v in x))
|
||||
@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x))
|
||||
@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS)
|
||||
@pytest.mark.parametrize("group_size", [128, None])
|
||||
def test_machete_heuristic(shape, atype: torch.dtype,
|
||||
wtype_zeropoints: Tuple[ScalarType, bool],
|
||||
group_size: Optional[int]):
|
||||
m, n, k = shape
|
||||
wtype, zero_points = wtype_zeropoints
|
||||
|
||||
if group_size is not None and k % group_size != 0:
|
||||
return
|
||||
|
||||
# Normalize group_size
|
||||
if group_size is None:
|
||||
group_size = k
|
||||
assert group_size <= k
|
||||
|
||||
a = rand_data((m, k), atype)
|
||||
b = rand_data((k, n), atype)
|
||||
|
||||
machete_gemm_test_helper(a, b, wtype, group_size, zero_points)
|
||||
|
||||
|
||||
# Test working on other devices
|
||||
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
|
||||
reason="Machete is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_machete_devices(device: str):
|
||||
m, n, k = 512, 4096, 4096
|
||||
wtype = scalar_types.uint4b8
|
||||
group_size = 128
|
||||
zero_points = False
|
||||
|
||||
print(f"MNK = {m} {n} {k}, device = {device}")
|
||||
|
||||
a = rand_data((m, k), torch.float16).to(device)
|
||||
b = rand_data((k, n), torch.float16).to(device)
|
||||
|
||||
machete_gemm_test_helper(a, b, wtype, group_size, zero_points)
|
||||
|
||||
|
||||
# Test working with a subset of A and B
|
||||
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
|
||||
reason="Machete is not supported on this GPU type.")
|
||||
def test_machete_subset():
|
||||
big_m, big_n, big_k = 1024, 1024, 1024
|
||||
m, n, k = 512, 512, 512
|
||||
wtype = scalar_types.uint4b8
|
||||
group_size = 128
|
||||
zero_points = False
|
||||
|
||||
whole_a = rand_data((big_m, big_k), torch.float16)
|
||||
whole_b = rand_data((big_k, big_n), torch.float16)
|
||||
|
||||
a = whole_a[0:m, 0:k]
|
||||
b = whole_b[0:k, 0:n]
|
||||
|
||||
machete_gemm_test_helper(a, b, wtype, group_size, zero_points)
|
||||
|
||||
|
||||
# Test to make sure cuda graphs work
|
||||
class MacheteLayer(torch.nn.Module):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.kwargs = kwargs
|
||||
|
||||
def forward(self, a):
|
||||
return ops.machete_gemm(**self.kwargs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
|
||||
reason="Machete is not supported on this GPU type.")
|
||||
def test_machete_cuda_graph():
|
||||
m, n, k = 512, 4096, 4096
|
||||
|
||||
a = rand_data((m, k), torch.float16)
|
||||
b = rand_data((k, n), torch.float16)
|
||||
wtype = scalar_types.uint4b8
|
||||
group_size = 128
|
||||
zero_points = False
|
||||
|
||||
w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
|
||||
b, wtype, group_size, zero_points)
|
||||
|
||||
# Construct a trivial model with a single layer that calls a machete kernel
|
||||
model = MacheteLayer(
|
||||
a=a,
|
||||
b_q=w_q_packed,
|
||||
b_type=wtype,
|
||||
b_scales=w_s,
|
||||
b_zeros=maybe_convert_zeropoints(w_zp, w_s),
|
||||
b_group_size=group_size,
|
||||
)
|
||||
|
||||
output_ref = torch.matmul(a, w_ref)
|
||||
|
||||
# Run the model with a cuda graph
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
output = model(a)
|
||||
output.zero_()
|
||||
g.replay()
|
||||
|
||||
# Relax atol as our reduction dim becomes larger (more rounding error)
|
||||
# Relax atol when we have zeropoints since the way machete applies
|
||||
# zeropoints (after scales) causes noise around 0
|
||||
atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1)
|
||||
torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol)
|
tests/kernels/test_machete_mm.py (new file, 406 lines)
@ -0,0 +1,406 @@
|
||||
"""Tests for the machete kernel.
|
||||
|
||||
Run `pytest tests/kernels/test_machete_mm.py`.
|
||||
"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
pack_rows, quantize_weights)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
# TODO: in a future PR refactor this and `is_quant_method_supported` in the
# kernel unit tests into a common utility function. Currently the use of
# `is_quant_method_supported` conflates kernels with quantization methods, an
# assumption which is breaking down as quantization methods can have multiple
# kernels and some kernels support multiple quantization methods.
|
||||
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
|
||||
|
||||
MNK_SHAPES = [
|
||||
(1, 128, 128),
|
||||
(1, 512, 1024),
|
||||
(1, 4096, 4096),
|
||||
(1, 8192, 28672),
|
||||
(13, 8192, 4096),
|
||||
(26, 4096, 8192),
|
||||
(64, 4096, 4096),
|
||||
(64, 8192, 28672),
|
||||
(257, 128, 4096),
|
||||
(257, 4224, 4160),
|
||||
(257, 4096, 4096),
|
||||
(1024, 4096, 8192),
|
||||
(1024, 8192, 4096),
|
||||
]
|
||||
|
||||
GROUP_SIZES_TO_TEST: List[Optional[int]] = [128, -1]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TypeConfig:
|
||||
act_type: torch.dtype
|
||||
weight_type: ScalarType
|
||||
output_type: Optional[torch.dtype]
|
||||
group_scale_type: Optional[torch.dtype]
|
||||
group_zero_type: Optional[torch.dtype]
|
||||
channel_scale_type: Optional[torch.dtype]
|
||||
token_scale_type: Optional[torch.dtype]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tensors:
|
||||
w_ref: torch.Tensor
|
||||
a_ref: torch.Tensor
|
||||
a: torch.Tensor
|
||||
w_q: torch.Tensor
|
||||
w_g_s: Optional[torch.Tensor]
|
||||
w_g_zp: Optional[torch.Tensor]
|
||||
w_ch_s: Optional[torch.Tensor]
|
||||
w_tok_s: Optional[torch.Tensor]
|
||||
|
||||
|
||||
# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
|
||||
# Ch Scales Type, Tok Scales Type)
|
||||
# NOTE: None "Scale Type" means the act type is floating point
|
||||
# None "Output Type" means the output type is the same as the act type
|
||||
TestTypeTuple = Tuple[List[torch.dtype], ScalarType, Optional[torch.dtype],
|
||||
Optional[torch.dtype], bool]
|
||||
TEST_TYPES = [
|
||||
# GPTQ style
|
||||
*(TypeConfig(act_type=a_type,
|
||||
weight_type=w_type,
|
||||
output_type=None,
|
||||
group_scale_type=a_type,
|
||||
group_zero_type=None,
|
||||
channel_scale_type=None,
|
||||
token_scale_type=None)
|
||||
for w_type in [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
for a_type in [torch.float16, torch.bfloat16]),
|
||||
# AWQ style
|
||||
*(TypeConfig(act_type=a_type,
|
||||
weight_type=w_type,
|
||||
output_type=None,
|
||||
group_scale_type=a_type,
|
||||
group_zero_type=a_type,
|
||||
channel_scale_type=None,
|
||||
token_scale_type=None)
|
||||
for w_type in [scalar_types.uint4, scalar_types.uint8]
|
||||
for a_type in [torch.float16, torch.bfloat16]),
|
||||
# QQQ style
|
||||
*(TypeConfig(act_type=torch.int8,
|
||||
weight_type=scalar_types.uint4b8,
|
||||
output_type=torch.float16,
|
||||
group_scale_type=group_scale_type,
|
||||
group_zero_type=None,
|
||||
channel_scale_type=torch.float,
|
||||
token_scale_type=torch.float)
|
||||
for group_scale_type in [None, torch.float16]),
|
||||
*(TypeConfig(act_type=torch.float8_e4m3fn,
|
||||
weight_type=scalar_types.uint4b8,
|
||||
output_type=torch.float16,
|
||||
group_scale_type=group_scale_type,
|
||||
group_zero_type=None,
|
||||
channel_scale_type=torch.float,
|
||||
token_scale_type=torch.float)
|
||||
for group_scale_type in [None, torch.float16]),
|
||||
]
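# NOTE: the last two groups above exercise the new W4A8 path: int8 or fp8
# activations with uint4b8 weights, fp32 channel-wise weight scales, fp32
# token-wise activation scales and an fp16 output type.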
|
||||
|
||||
# TODO: in a future PR refactor this and `is_quant_method_supported` in the
# kernel unit tests into a common utility function. Currently the use of
# `is_quant_method_supported` conflates kernels with quantization methods, an
# assumption which is breaking down as quantization methods can have multiple
# kernels and some kernels support multiple quantization methods.
|
||||
IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)
|
||||
|
||||
|
||||
def rand_data(shape, dtype=torch.float16, scale=1, offset=0):
|
||||
if dtype.is_floating_point:
|
||||
return (scale * torch.rand(shape, device="cuda") - offset).to(dtype)
|
||||
else:
|
||||
return torch.randint(-8, 7, shape, dtype=dtype, device="cuda")
|
||||
|
||||
|
||||
def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
|
||||
return zps if zps is None else -1 * s * (zps.to(s.dtype))
|
||||
|
||||
|
||||
def group_size_valid(shape: Tuple[int, int, int],
|
||||
group_size: Optional[int]) -> bool:
|
||||
    return group_size is None or group_size == -1 or shape[2] % group_size == 0
|
||||
|
||||
|
||||
def machete_quantize_and_pack(atype: torch.dtype,
|
||||
w: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
stype: Optional[torch.dtype],
|
||||
group_size: Optional[int],
|
||||
zero_points: bool = False):
|
||||
assert wtype.is_integer(), "TODO: support floating point weights"
|
||||
|
||||
w_ref, w_q, w_s, w_zp = quantize_weights(
|
||||
w,
|
||||
wtype,
|
||||
group_size=group_size,
|
||||
zero_points=zero_points,
|
||||
# to match how the kernel applies zps
|
||||
ref_zero_points_after_scales=True)
|
||||
|
||||
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
|
||||
w_q = w_q.t().contiguous().t() # convert to col major
|
||||
|
||||
w_q_machete = ops.machete_prepack_B(w_q, atype, wtype, stype)
|
||||
opcheck(torch.ops._C.machete_prepack_B, (w_q, atype, wtype.id, stype))
|
||||
|
||||
return w_ref, w_q_machete, w_s, w_zp
|
||||
|
||||
|
||||
def create_test_tensors(shape: Tuple[int, int, int],
|
||||
types: TypeConfig,
|
||||
group_size: Optional[int],
|
||||
subset_stride_factor: Optional[int] = None) -> Tensors:
|
||||
m, n, k = shape
|
||||
factor = subset_stride_factor or 1
|
||||
|
||||
print("create_test_tensors, shape:", shape, "types:", types, "group_size:",
|
||||
group_size)
|
||||
|
||||
a = rand_data((m * factor, k * factor), types.act_type, scale=3, offset=2)
|
||||
w = rand_data((k * factor, n * factor), types.act_type, scale=3, offset=1)
|
||||
|
||||
if factor > 1:
|
||||
a = a[0:m, 0:k]
|
||||
w = w[0:k, 0:n]
|
||||
|
||||
if types.group_scale_type is not None:
|
||||
w = w.to(types.group_scale_type)
|
||||
if w.dtype.itemsize == 1:
|
||||
w = w.to(torch.float16)
|
||||
|
||||
w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
|
||||
a.dtype, w, types.weight_type, types.group_scale_type, group_size,
|
||||
types.group_zero_type is not None)
|
||||
|
||||
if not a.dtype.is_floating_point:
|
||||
aiinfo = torch.iinfo(a.dtype)
|
||||
w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max)
|
||||
|
||||
a_ref = a.to(torch.float32)
|
||||
w_ref = w_ref.to(torch.float32)
|
||||
|
||||
w_ch_s = None if types.channel_scale_type is None else\
|
||||
rand_data((n,), types.channel_scale_type)
|
||||
w_tok_s = None if types.token_scale_type is None else\
|
||||
rand_data((m,), types.token_scale_type)
|
||||
|
||||
return Tensors(w_ref=w_ref,
|
||||
a_ref=a_ref,
|
||||
a=a,
|
||||
w_q=w_q_packed,
|
||||
w_g_s=w_s,
|
||||
w_g_zp=maybe_convert_zeropoints(w_zp, w_s),
|
||||
w_ch_s=w_ch_s,
|
||||
w_tok_s=w_tok_s)
|
||||
|
||||
|
||||
# None stype means scales use the same dtype as a
|
||||
def machete_mm_test_helper(types: TypeConfig,
|
||||
tensors: Tensors,
|
||||
group_size: Optional[int] = None,
|
||||
schedule: Optional[str] = None):
|
||||
output_ref = torch.matmul(tensors.a_ref, tensors.w_ref)
|
||||
output_ref_type = output_ref.dtype
|
||||
|
||||
if tensors.w_ch_s is not None:
|
||||
output_ref = (output_ref.to(tensors.w_ch_s.dtype) *
|
||||
tensors.w_ch_s.unsqueeze(0)).to(output_ref_type)
|
||||
if tensors.w_tok_s is not None:
|
||||
output_ref = (output_ref.to(tensors.w_tok_s.dtype) *
|
||||
tensors.w_tok_s.unsqueeze(1)).to(output_ref_type)
|
||||
|
||||
output = ops.machete_mm(
|
||||
a=tensors.a,
|
||||
b_q=tensors.w_q,
|
||||
b_type=types.weight_type,
|
||||
b_group_scales=tensors.w_g_s,
|
||||
b_group_zeros=tensors.w_g_zp,
|
||||
b_group_size=group_size,
|
||||
b_channel_scales=tensors.w_ch_s,
|
||||
a_token_scales=tensors.w_tok_s,
|
||||
out_type=types.output_type,
|
||||
schedule=schedule,
|
||||
)
|
||||
|
||||
print(output)
|
||||
print(output_ref)
|
||||
|
||||
# Relax atol as our reduction dim becomes larger (more rounding error)
|
||||
# Relax atol when we have zeropoints since the way machete applies
|
||||
# zeropoints (after scales) causes noise around 0
|
||||
atol = 1 if tensors.w_g_zp is not None\
|
||||
else min(5e-2 * math.sqrt(tensors.a.shape[1]), 1)
|
||||
rtol = 1e-1 if tensors.a.element_size() >= 2 else 2e-1
|
||||
torch.testing.assert_close(output,
|
||||
output_ref.to(output.dtype),
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
|
||||
reason="Machete is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("shape",
|
||||
MNK_SHAPES,
|
||||
ids=lambda x: "x".join(str(v) for v in x))
|
||||
@pytest.mark.parametrize("types", TEST_TYPES)
|
||||
def test_machete_all_schedules(shape, types: TypeConfig):
|
||||
|
||||
group_sizes: List[Optional[int]] = []
|
||||
if types.group_scale_type is None:
|
||||
group_sizes = [None]
|
||||
else:
|
||||
group_sizes = GROUP_SIZES_TO_TEST
|
||||
|
||||
for group_size in group_sizes:
|
||||
if not group_size_valid(shape, group_size):
|
||||
continue
|
||||
|
||||
tensors = create_test_tensors(shape, types, group_size)
|
||||
print(f"MNK = {shape}")
|
||||
for schedule in ops.machete_supported_schedules(
|
||||
types.act_type,
|
||||
types.weight_type,
|
||||
group_scales_type=types.group_scale_type,
|
||||
group_zeros_type=types.group_scale_type,
|
||||
out_type=types.output_type):
|
||||
print(f"Testing schedule {schedule}")
|
||||
machete_mm_test_helper(types, tensors, group_size, schedule)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
|
||||
reason="Machete is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("shape",
|
||||
MNK_SHAPES,
|
||||
ids=lambda x: "x".join(str(v) for v in x))
|
||||
@pytest.mark.parametrize("types", TEST_TYPES)
|
||||
def test_machete_heuristic(shape, types: TypeConfig):
|
||||
group_sizes: List[Optional[int]] = []
|
||||
if types.group_scale_type is None:
|
||||
group_sizes = [None]
|
||||
else:
|
||||
group_sizes = GROUP_SIZES_TO_TEST
|
||||
|
||||
for group_size in group_sizes:
|
||||
if not group_size_valid(shape, group_size):
|
||||
continue
|
||||
|
||||
tensors = create_test_tensors(shape, types, group_size)
|
||||
machete_mm_test_helper(types, tensors, group_size)
|
||||
|
||||
|
||||
# Test working on other devices
|
||||
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
|
||||
reason="Machete is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_machete_devices(device: str):
|
||||
group_size = 128
|
||||
|
||||
type_config = TypeConfig(act_type=torch.float16,
|
||||
weight_type=scalar_types.uint4b8,
|
||||
output_type=None,
|
||||
group_scale_type=torch.float16,
|
||||
group_zero_type=None,
|
||||
channel_scale_type=None,
|
||||
token_scale_type=None)
|
||||
|
||||
tensors = create_test_tensors((512, 4096, 4096), type_config, group_size)
|
||||
|
||||
for field in fields(Tensors):
|
||||
tensor = getattr(tensors, field.name)
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
setattr(tensors, field.name, tensor.to(device))
|
||||
|
||||
machete_mm_test_helper(type_config, tensors, group_size)
|
||||
|
||||
|
||||
# Test working with a subset of A and B
|
||||
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
|
||||
reason="Machete is not supported on this GPU type.")
|
||||
def test_machete_subset():
|
||||
group_size = 128
|
||||
|
||||
type_config = TypeConfig(act_type=torch.float16,
|
||||
weight_type=scalar_types.uint4b8,
|
||||
output_type=None,
|
||||
group_scale_type=torch.float16,
|
||||
group_zero_type=None,
|
||||
channel_scale_type=None,
|
||||
token_scale_type=None)
|
||||
|
||||
tensors = create_test_tensors((512, 4096, 4096),
|
||||
type_config,
|
||||
group_size,
|
||||
subset_stride_factor=2)
|
||||
machete_mm_test_helper(type_config, tensors, group_size)
|
||||
|
||||
|
||||
# Test to make sure cuda graphs work
|
||||
class MacheteLayer(torch.nn.Module):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.kwargs = kwargs
|
||||
|
||||
def forward(self, a):
|
||||
return ops.machete_mm(a=a, **self.kwargs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
|
||||
reason="Machete is not supported on this GPU type.")
|
||||
def test_machete_cuda_graph():
|
||||
m, n, k = 512, 4096, 4096
|
||||
|
||||
a = rand_data((m, k), torch.float16)
|
||||
b = rand_data((k, n), torch.float16)
|
||||
wtype = scalar_types.uint4b8
|
||||
stype = torch.float16
|
||||
group_size = 128
|
||||
zero_points = False
|
||||
|
||||
w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
|
||||
a.dtype, b, wtype, stype, group_size, zero_points)
|
||||
|
||||
# Construct a trivial model with a single layer that calls a machete kernel
|
||||
model = MacheteLayer(
|
||||
b_q=w_q_packed,
|
||||
b_type=wtype,
|
||||
b_group_scales=w_s,
|
||||
b_group_zeros=maybe_convert_zeropoints(w_zp, w_s),
|
||||
b_group_size=group_size,
|
||||
)
|
||||
|
||||
output_ref = torch.matmul(a, w_ref)
|
||||
|
||||
# Run the model with a cuda graph
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
output = model(a)
|
||||
output.zero_()
|
||||
g.replay()
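# Zero the captured output and replay the graph: if the machete_mm launch was
# captured correctly, the replay repopulates `output` without running any
# eager Python code.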
|
||||
|
||||
# Relax atol as our reduction dim becomes larger (more rounding error)
|
||||
# Relax atol when we have zeropoints since the way machete applies
|
||||
# zeropoints (after scales) causes noise around 0
|
||||
atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1)
|
||||
torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol)
|
@ -444,18 +444,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
size_k: torch.SymInt) -> torch.Tensor:
|
||||
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
|
||||
|
||||
@register_fake("_C::machete_gemm")
|
||||
def machete_gemm_fake(
|
||||
@register_fake("_C::machete_mm")
|
||||
def machete_mm_fake(
|
||||
a: torch.Tensor,
|
||||
# Should be the tensor returned by machete_prepack_B
|
||||
# b_q should be the tensor returned by machete_prepack_B
|
||||
b_q: torch.Tensor,
|
||||
b_type: ScalarType,
|
||||
b_scales: Optional[torch.Tensor] = None,
|
||||
b_zeros: Optional[torch.Tensor] = None,
|
||||
out_type: Optional[torch.dtype] = None,
|
||||
b_group_scales: Optional[torch.Tensor] = None,
|
||||
b_group_zeros: Optional[torch.Tensor] = None,
|
||||
b_group_size: Optional[int] = None,
|
||||
c: Optional[torch.Tensor] = None,
|
||||
alpha: Optional[float] = None,
|
||||
beta: Optional[float] = None,
|
||||
b_channel_scales: Optional[torch.Tensor] = None,
|
||||
a_token_scales: Optional[torch.Tensor] = None,
|
||||
schedule: Optional[str] = None,
|
||||
) -> torch.Tensor:
|
||||
m = a.size(0)
|
||||
@ -463,8 +463,9 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
return torch.empty((m, n), device=a.device, dtype=a.dtype)
|
||||
|
||||
@register_fake("_C::machete_prepack_B")
|
||||
def machete_prepack_B_fake(b_q_weight: torch.Tensor,
|
||||
b_type: ScalarType) -> torch.Tensor:
|
||||
def machete_prepack_B_fake(
|
||||
b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType,
|
||||
group_scales_type: Optional[torch.dtype]) -> torch.Tensor:
|
||||
return torch.empty_like(b_q_weight,
|
||||
memory_format=torch.contiguous_format)
|
||||
|
||||
@ -617,29 +618,41 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
|
||||
|
||||
# machete
|
||||
def machete_supported_schedules(b_type: ScalarType) -> List[str]:
|
||||
return torch.ops._C.machete_supported_schedules(b_type.id)
|
||||
def machete_supported_schedules(
|
||||
a_type: torch.dtype,
|
||||
b_type: ScalarType,
|
||||
group_scales_type: Optional[torch.dtype],
|
||||
group_zeros_type: Optional[torch.dtype] = None,
|
||||
channel_scales_type: Optional[torch.dtype] = None,
|
||||
token_scales_type: Optional[torch.dtype] = None,
|
||||
out_type: Optional[torch.dtype] = None) -> List[str]:
|
||||
return torch.ops._C.machete_supported_schedules(
|
||||
a_type, b_type.id, group_scales_type, group_zeros_type,
|
||||
channel_scales_type, token_scales_type, out_type)
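Typical usage, as exercised by the new tests (a sketch; the dtypes and the
ScalarType are illustrative):
schedules = ops.machete_supported_schedules(
    torch.float16, scalar_types.uint4b8,
    group_scales_type=torch.float16,
    out_type=torch.float16)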
|
||||
|
||||
|
||||
def machete_gemm(
|
||||
a: torch.Tensor,
|
||||
b_q: torch.Tensor, # Should be the tensor returned by machete_prepack_B
|
||||
b_type: ScalarType,
|
||||
b_scales: Optional[torch.Tensor] = None,
|
||||
b_zeros: Optional[torch.Tensor] = None,
|
||||
b_group_size: Optional[int] = None,
|
||||
c: Optional[torch.Tensor] = None,
|
||||
alpha: Optional[float] = None,
|
||||
beta: Optional[float] = None,
|
||||
schedule: Optional[str] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops._C.machete_gemm(a, b_q, b_type.id, b_scales, b_zeros,
|
||||
b_group_size, c, alpha, beta, schedule)
|
||||
def machete_mm(
|
||||
a: torch.Tensor,
|
||||
# b_q should be the tensor returned by machete_prepack_B
|
||||
b_q: torch.Tensor,
|
||||
b_type: ScalarType,
|
||||
out_type: Optional[torch.dtype] = None,
|
||||
b_group_scales: Optional[torch.Tensor] = None,
|
||||
b_group_zeros: Optional[torch.Tensor] = None,
|
||||
b_group_size: Optional[int] = None,
|
||||
b_channel_scales: Optional[torch.Tensor] = None,
|
||||
a_token_scales: Optional[torch.Tensor] = None,
|
||||
schedule: Optional[str] = None) -> torch.Tensor:
|
||||
return torch.ops._C.machete_mm(a, b_q, b_type.id, out_type, b_group_scales,
|
||||
b_group_zeros, b_group_size,
|
||||
b_channel_scales, a_token_scales, schedule)
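Putting the new ops together, a minimal fp16-activation / 4-bit-weight flow
looks roughly like this (a sketch based on the test helpers in this PR; `a`,
`w` and the quant_utils helpers quantize_weights/pack_rows are assumed to be
available and are illustrative):
w_ref, w_q, w_s, _ = quantize_weights(w, scalar_types.uint4b8, group_size=128)
w_q = pack_rows(w_q, scalar_types.uint4b8.size_bits, *w_q.shape)
w_q = w_q.t().contiguous().t()  # column major
w_q = ops.machete_prepack_B(w_q, a.dtype, scalar_types.uint4b8,
                            group_scales_type=w_s.dtype)
out = ops.machete_mm(a=a, b_q=w_q, b_type=scalar_types.uint4b8,
                     b_group_scales=w_s, b_group_size=128)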
|
||||
|
||||
|
||||
def machete_prepack_B(b_q_weight: torch.Tensor,
|
||||
b_type: ScalarType) -> torch.Tensor:
|
||||
return torch.ops._C.machete_prepack_B(b_q_weight, b_type.id)
|
||||
def machete_prepack_B(
|
||||
b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType,
|
||||
group_scales_type: Optional[torch.dtype]) -> torch.Tensor:
|
||||
return torch.ops._C.machete_prepack_B(b_q_weight, a_type, b_type.id,
|
||||
group_scales_type)
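From the caller's side the intended prepack pattern looks roughly like this
(a sketch based on the test helpers in this PR; `w_q` is the row-packed
quantized weight, `w_s` its group scales and `a` the activation tensor, all
illustrative):
w_q = w_q.t().contiguous().t()  # machete expects a column-major packed B
w_q_machete = ops.machete_prepack_B(w_q, a.dtype, wtype,
                                    group_scales_type=w_s.dtype)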
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "permute_cols"):
|
||||
|
@ -79,7 +79,9 @@ class MacheteLinearKernel(MPLinearKernel):
|
||||
c.weight_type,
|
||||
packed_dim=0)
|
||||
x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
|
||||
self.config.weight_type)
|
||||
a_type=c.act_type,
|
||||
b_type=c.weight_type,
|
||||
group_scales_type=c.act_type)
|
||||
return x
|
||||
|
||||
def transform_w_s(x):
|
||||
@ -105,12 +107,12 @@ class MacheteLinearKernel(MPLinearKernel):
|
||||
if c.has_g_idx:
|
||||
x_2d = self.act_perm(x_2d)
|
||||
|
||||
output = ops.machete_gemm(a=x_2d,
|
||||
b_q=w_q,
|
||||
b_type=c.weight_type,
|
||||
b_zeros=None,
|
||||
b_scales=w_s,
|
||||
b_group_size=c.group_size)
|
||||
output = ops.machete_mm(a=x_2d,
|
||||
b_q=w_q,
|
||||
b_type=c.weight_type,
|
||||
b_group_zeros=None,
|
||||
b_group_scales=w_s,
|
||||
b_group_size=c.group_size)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
@ -126,11 +126,14 @@ def permute_rows(q_w: torch.Tensor,
|
||||
|
||||
def quantize_weights(w: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
group_size: Optional[int],
|
||||
zero_points: bool = False,
|
||||
ref_zero_points_after_scales: bool = False):
|
||||
assert quant_type.is_integer(), \
|
||||
"Floating point quantization may work but has not been tested"
|
||||
assert not zero_points or group_size is not None, \
|
||||
"to have group zero points, group_size must be provided "\
|
||||
"(-1 group_size is channelwise)"
|
||||
|
||||
orig_device = w.device
|
||||
orig_type = w.dtype
|
||||
@ -140,10 +143,9 @@ def quantize_weights(w: torch.Tensor,
|
||||
|
||||
if group_size == -1:
|
||||
group_size = size_k
|
||||
assert group_size <= size_k
|
||||
|
||||
# Reshape to [groupsize, -1]
|
||||
if group_size < size_k:
|
||||
if group_size is not None and group_size < size_k:
|
||||
w = w.reshape((-1, group_size, size_n))
|
||||
w = w.permute(1, 0, 2)
|
||||
w = w.reshape((group_size, -1))
|
||||
@ -155,18 +157,20 @@ def quantize_weights(w: torch.Tensor,
|
||||
max_q_val = quant_type.max()
|
||||
min_q_val = quant_type.min()
|
||||
|
||||
if zero_points:
|
||||
assert not quant_type.is_signed() and quant_type.max() > 0
|
||||
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
|
||||
maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
|
||||
.clamp(min_q_val, max_q_val).int()
|
||||
else:
|
||||
# If the bias is such that there are no possible negative/positive
|
||||
# values, set the max value to inf to avoid divide by 0
|
||||
w_s = torch.max(
|
||||
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
|
||||
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
|
||||
maybe_w_zp = None
|
||||
w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
|
||||
maybe_w_zp = None
|
||||
if group_size is not None:
|
||||
if zero_points:
|
||||
assert not quant_type.is_signed() and quant_type.max() > 0
|
||||
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
|
||||
maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
|
||||
.clamp(min_q_val, max_q_val).int()
|
||||
else:
|
||||
# If the bias is such that there are no possible negative/positive
|
||||
# values, set the max value to inf to avoid divide by 0
|
||||
w_s = torch.max(
|
||||
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
|
||||
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
|
||||
|
||||
# Quantize
|
||||
w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
|
||||
@ -176,7 +180,7 @@ def quantize_weights(w: torch.Tensor,
|
||||
# For some kernels (namely Machete) the zero-points are applied after the
|
||||
# scales are applied, for this case computing the reference in similar way
|
||||
# allows us to use tighter error tolerances in our unit tests.
|
||||
if ref_zero_points_after_scales and zero_points:
|
||||
if ref_zero_points_after_scales and maybe_w_zp is not None:
|
||||
w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
|
||||
else:
|
||||
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
|
||||
@ -185,7 +189,7 @@ def quantize_weights(w: torch.Tensor,
|
||||
w_q += quant_type.bias
|
||||
|
||||
# Restore original shapes
|
||||
if group_size < size_k:
|
||||
if group_size is not None and group_size < size_k:
|
||||
|
||||
def reshape_w(w):
|
||||
w = w.reshape((group_size, -1, size_n))
|
||||
@ -195,17 +199,16 @@ def quantize_weights(w: torch.Tensor,
|
||||
|
||||
w_q = reshape_w(w_q)
|
||||
w_ref = reshape_w(w_ref)
|
||||
w_s = w_s.reshape((-1, size_n)).contiguous()
|
||||
|
||||
w_s = w_s.reshape((-1, size_n)).contiguous()
|
||||
|
||||
if zero_points:
|
||||
if maybe_w_zp is not None:
|
||||
maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
|
||||
maybe_w_zp = maybe_w_zp.to(device=orig_device)
|
||||
|
||||
return (
|
||||
w_ref.to(device=orig_device),
|
||||
w_q.to(device=orig_device),
|
||||
w_s.to(device=orig_device),
|
||||
w_s if group_size is not None else None,
|
||||
maybe_w_zp,
|
||||
)
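With this change quantize_weights supports three group_size modes; a quick
sketch of the call sites (the weight tensor `w` is illustrative):
w_ref, w_q, w_s, _ = quantize_weights(w, scalar_types.uint4b8, group_size=128)   # grouped scales
w_ref, w_q, w_s, _ = quantize_weights(w, scalar_types.uint4b8, group_size=-1)    # channelwise scales
w_ref, w_q, w_s, _ = quantize_weights(w, scalar_types.uint4b8, group_size=None)  # no scales (w_s is None)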