[Kernel] Initial Machete W4A8 support + Refactors (#9855)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

parent c2170a5b39
commit 96d999fbe8
@@ -2,8 +2,10 @@ import argparse
 import copy
 import itertools
 import math
+import os
 import pickle as pkl
 import time
+from dataclasses import dataclass
 from itertools import product
 from typing import Callable, Iterable, List, Optional, Tuple

@@ -15,11 +17,12 @@ from weight_shapes import WEIGHT_SHAPES

 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales)
+    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales,
+    marlin_zero_points)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     MarlinWorkspace)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    gptq_pack, pack_rows, quantize_weights)
+    pack_rows, quantize_weights)
 from vllm.scalar_type import ScalarType, scalar_types
 from vllm.utils import FlexibleArgumentParser

@@ -27,149 +30,349 @@ DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
 DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
 DEFAULT_TP_SIZES = [1]

-def machete_pack_weights(w_q: torch.tensor, wtype: ScalarType) -> torch.tensor:
-    w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
-    w_q = w_q.t().contiguous().t()  # make col major
-    return ops.machete_prepack_B(w_q, wtype)
-
-
-def make_bench_tensors(
-    atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int,
-    k: int
-) -> Tuple[torch.tensor, List[Tuple[torch.tensor, torch.tensor, torch.tensor,
-                                    torch.tensor]]]:
+NVTX_PROFILE = os.environ.get("NVTX_PROFILE", False)
+
+if NVTX_PROFILE:
+    import nvtx
+
+
+def terse_type_name(dt):
+    return {
+        torch.bfloat16: "bf16",
+        torch.float16: "fp16",
+        torch.int8: "int8",
+        torch.float8_e4m3fn: "fp8",
+        torch.bfloat16: "bf16",
+        torch.float: "float",
+        torch.int: "int",
+    }[dt]
+
+
+@dataclass
+class BenchmarkTensors:
+    w_ref: torch.Tensor
+    a: torch.Tensor
+
+    w_q: torch.Tensor
+    group_size: Optional[int]
+    wtype: ScalarType
+    w_g_s: torch.Tensor
+    w_g_zp: Optional[torch.Tensor]
+    w_ch_s: Optional[torch.Tensor]
+    w_tok_s: Optional[torch.Tensor]
+
+
+@dataclass
+class TypeConfig:
+    act_type: torch.dtype
+    weight_type: ScalarType
+    output_type: Optional[torch.dtype]
+    group_scale_type: Optional[torch.dtype]
+    group_zero_type: Optional[torch.dtype]
+    channel_scale_type: Optional[torch.dtype]
+    token_scale_type: Optional[torch.dtype]
+
+
+def rand_data(shape, dtype=torch.float16, scale=1):
+    if dtype.is_floating_point:
+        return (scale * torch.rand(shape, device="cuda") - 0.3).to(dtype)
+    else:
+        return torch.randint(-15, 15, shape, dtype=dtype, device="cuda")
+
+
+def quantize_and_pack(atype: torch.dtype,
+                      w: torch.Tensor,
+                      wtype: ScalarType,
+                      stype: Optional[torch.dtype],
+                      group_size: Optional[int],
+                      zero_points: bool = False):
     assert wtype.is_integer(), "TODO: support floating point weights"

+    w_ref, w_q, w_s, w_zp = quantize_weights(
+        w,
+        wtype,
+        group_size=group_size,
+        zero_points=zero_points,
+        # to match how the kernel applies zps
+        ref_zero_points_after_scales=True)
+
+    w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
+    return w_ref, w_q, w_s, w_zp
+
+
+def create_bench_tensors(shape: Tuple[int, int, int], types: TypeConfig,
+                         group_size: Optional[int]) -> List[BenchmarkTensors]:
+    m, n, k = shape
+
     # we want to make sure that weights don't fit into L2 cache between runs so
     # we construct enough weights to exceed L2 cache, which is 50mb on a H100
     # so we target total weight size > 2*50mb
-    num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits))
+    num_weights = math.ceil(2 * 50 * 1024**2 * 8 /
+                            (k * n * types.weight_type.size_bits))

-    a = torch.randn((m, k), device="cuda", dtype=atype) * 5
-    weights = [
-        torch.randn((k, n), device="cuda", dtype=atype)
-        for _ in range(num_weights)
-    ]
-    quanitized_weights = [
-        quantize_weights(w, wtype, group_size) for w in weights
-    ]
-
-    return a, quanitized_weights
+    a = rand_data((m, k), types.act_type, scale=5)
+
+    benchmark_tensors: List[BenchmarkTensors] = []
+    for _ in range(num_weights):
+        w = rand_data((k, n), types.act_type, scale=5)
+
+        if types.group_scale_type is not None:
+            w = w.to(types.group_scale_type)
+        if w.dtype.itemsize == 1:
+            w = w.to(torch.float16)
+
+        w_ref, w_q_packed, w_s, w_zp = quantize_and_pack(
+            a.dtype, w, types.weight_type, types.group_scale_type, group_size,
+            types.group_zero_type is not None)
+
+        if not a.dtype.is_floating_point:
+            aiinfo = torch.iinfo(a.dtype)
+            w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max)
+
+        w_ref = w_ref.to(torch.float32)
+
+        w_ch_s = None if types.channel_scale_type is None else\
+            rand_data((n,), types.channel_scale_type)
+        w_tok_s = None if types.token_scale_type is None else\
+            rand_data((m,), types.token_scale_type)
+
+        benchmark_tensors.append(
+            BenchmarkTensors(w_ref=w_ref,
+                             a=a,
+                             w_q=w_q_packed,
+                             wtype=types.weight_type,
+                             w_g_s=w_s,
+                             w_g_zp=w_zp,
+                             group_size=group_size,
+                             w_ch_s=w_ch_s,
+                             w_tok_s=w_tok_s))
+
+    return benchmark_tensors
+
+
+def torch_matmul_f16_create_bench_fn(bt: BenchmarkTensors) -> Callable:
+    a = bt.a
+    w = bt.w_ref.to(bt.a.dtype)  # use float reference tensor
+    if a.dtype not in [torch.float16, torch.bfloat16]:
+        a = a.to(torch.float16)
+        w = w.to(torch.float16)
+    return lambda: torch.matmul(a, w)
+
+
+def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable:
+    if bt.w_ch_s is not None and bt.w_tok_s is not None:
+        scale_a = bt.w_tok_s.to(torch.float32)
+        scale_b = bt.w_ch_s.to(torch.float32)
+    else:
+        scale_a = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
+        scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
+    w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t()
+    return lambda: ops.cutlass_scaled_mm(
+        bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16)
+
+
+def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
+    device = bt.a.device
+
+    workspace = MarlinWorkspace(bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
+                                GPTQ_MARLIN_MAX_PARALLEL)
+
+    if bt.w_g_zp is None:
+        w_zp = torch.empty(0, dtype=torch.int, device=device)
+    else:
+        w_zp = marlin_zero_points(bt.w_g_zp, bt.w_ref.shape[0],
+                                  bt.w_ref.shape[1], bt.wtype.size_bits)
+
+    if bt.group_size is None:
+        w_s = torch.tensor([], device="cuda", dtype=torch.half)
+    else:
+        w_s = marlin_permute_scales(bt.w_g_s, bt.w_ref.shape[0],
+                                    bt.w_ref.shape[1], bt.group_size)
+
+    sort_indices = torch.empty(0, dtype=torch.int, device=device)
+    g_idx = torch.empty(0, dtype=torch.int, device=device)
+    w_q = ops.gptq_marlin_repack(bt.w_q, sort_indices, bt.w_ref.shape[0],
+                                 bt.w_ref.shape[1], bt.wtype.size_bits)
+
+    if bt.a.dtype.is_floating_point:
+        assert bt.w_ch_s is None
+        assert bt.w_tok_s is None
+        assert bt.group_size is not None
+
+        fn = lambda: ops.gptq_marlin_gemm(a=bt.a,
+                                          b_q_weight=w_q,
+                                          b_scales=w_s,
+                                          b_zeros=w_zp,
+                                          g_idx=g_idx,
+                                          perm=sort_indices,
+                                          workspace=workspace.scratch,
+                                          b_q_type=bt.wtype,
+                                          size_m=bt.a.shape[0],
+                                          size_n=bt.w_ref.shape[1],
+                                          size_k=bt.w_ref.shape[0],
+                                          is_k_full=True)
+    else:
+        assert bt.a.dtype == torch.int8
+        assert bt.wtype == scalar_types.uint4b8
+
+        if bt.w_ch_s is not None:
+            s_ch = bt.w_ch_s.to(torch.float32)
+        else:
+            s_ch = torch.ones(bt.w_ref.shape[1],
+                              dtype=torch.float32,
+                              device=device)
+
+        if bt.w_tok_s is not None:
+            s_tok = bt.w_tok_s.to(torch.float32)
+        else:
+            s_tok = torch.ones(bt.a.shape[0],
+                               dtype=torch.float32,
+                               device=device)
+
+        fn = lambda: ops.marlin_qqq_gemm(a=bt.a,
+                                         b_q_weight=w_q,
+                                         s_group=w_s,
+                                         s_tok=s_tok,
+                                         s_ch=s_ch,
+                                         workspace=workspace.scratch,
+                                         size_m=bt.a.shape[0],
+                                         size_n=bt.w_ref.shape[1],
+                                         size_k=bt.w_ref.shape[0])
+
+    return fn
+
+
+def machete_create_bench_fn(bt: BenchmarkTensors,
+                            out_type=torch.dtype,
+                            schedule=None) -> Callable:
+    w_q = bt.w_q.t().contiguous().t()  # make col major
+    w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype,
+                                None if bt.w_g_s is None else bt.w_g_s.dtype)
+
+    w_g_zp = bt.w_g_zp
+    if w_g_zp is not None:
+        w_g_zp = -1 * bt.w_g_s * (w_g_zp.to(bt.w_g_s.dtype))
+
+    return lambda: ops.machete_mm(
+        a=bt.a,
+        b_q=bt.w_q,
+        b_type=bt.wtype,
+        b_group_scales=bt.w_g_s,
+        b_group_zeros=w_g_zp,
+        b_group_size=bt.group_size,
+        b_channel_scales=bt.w_ch_s,
+        a_token_scales=bt.w_tok_s,
+        out_type=out_type,
+        schedule=schedule,
+    )


 # impl


 # bench
-def bench_fn(label: str, sub_label: str, description: str,
-             fn: Callable) -> TMeasurement:
-
-    min_run_time = 1
-    return TBenchmark.Timer(
-        stmt="fn()",
+def bench_fns(label: str, sub_label: str, description: str,
+              fns: List[Callable]):
+
+    min_run_time = 1 if not NVTX_PROFILE else 0.1
+    res = TBenchmark.Timer(
+        stmt="""
+        for fn in fns:
+            fn()
+        """,
         globals={
-            "fn": fn
+            "fns": fns
         },
         label=label,
         sub_label=sub_label,
         description=description,
     ).blocked_autorange(min_run_time=min_run_time)

-
-def loop_over_weights(
-        a: torch.tensor, weights: List[Tuple[torch.tensor, torch.tensor,
-                                             torch.tensor, torch.tensor]],
-        fn: Callable[[torch.tensor, torch.tensor, torch.tensor, torch.tensor],
-                     None]):
-    for w_ref, w_q, w_s, _ in weights:
-        fn(a, w_ref, w_q, w_s)
+    if NVTX_PROFILE:
+        with nvtx.annotate("mm-bench"), nvtx.annotate(
+                f"{label}|{sub_label}|{description}"):
+            fns[0]()
+
+    return res


 _SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
 _SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None


-def bench(atype: torch.dtype,
-          wtype: ScalarType,
+def bench(types: TypeConfig,
           group_size: int,
           m: int,
           k: int,
           n: int,
           label: str,
           sub_label: str,
-          benchmark_marlinv1: bool = True,
-          sweep_schedules: bool = True) -> Iterable[TMeasurement]:
-    global _SWEEP_SCHEDULES_RESULTS
-
-    a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
-    sub_label += f", L={len(weights)}"
-
-    weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp)
-                       for w_ref, w_q, w_s, w_zp in weights]
+          sweep_schedules: bool = True) -> List[TMeasurement]:
+    benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
+    sub_label += f", L={len(benchmark_tensors)}"
+
+    name_type_string = f"W{types.weight_type}"+\
+        f"-A{terse_type_name(types.act_type)}"
+    if types.group_scale_type is not None:
+        name_type_string += f"-GS{terse_type_name(types.group_scale_type)}"
+    if types.group_zero_type is not None:
+        name_type_string += f"-GZ{terse_type_name(types.group_zero_type)}"
+    if group_size is not None:
+        name_type_string += f"-G{group_size}"
+    if types.channel_scale_type is not None:
+        name_type_string += f"-CS{terse_type_name(types.channel_scale_type)}"
+    if types.token_scale_type is not None:
+        name_type_string += f"-TS{terse_type_name(types.token_scale_type)}"

     timers = []
     # pytorch impl
     timers.append(
-        bench_fn(
-            label, sub_label, "torch.matmul", lambda: loop_over_weights(
-                a,
-                weights,
-                lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref),
-            )))
-
-    if benchmark_marlinv1:
-        w_ref = weights[0][0]
-
-        w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device)
-        sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device)
-        g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device)
-
-        def marlinv1_pack_weights(w_q: torch.tensor) -> torch.tensor:
-            w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape)
-            return ops.gptq_marlin_repack(w_q_gptq, sort_indices, *w_ref.shape,
-                                          wtype.size_bits)
-
-        def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
-            return marlin_permute_scales(w_s, *w_ref.shape, group_size)
-
-        weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q),
-                             marlinv1_permute_scales(w_s), w_zp)
-                            for w_ref, w_q, w_s, w_zp in weights]
-
-        workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
-                                    GPTQ_MARLIN_MAX_PARALLEL)
-
-        # marlinv1
+        bench_fns(
+            label, sub_label, "torch.matmul (fp16)",
+            [torch_matmul_f16_create_bench_fn(bt)
+             for bt in benchmark_tensors]))
+
+    if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn:
         timers.append(
-            bench_fn(
-                label, sub_label, "marlin_orig", lambda: loop_over_weights(
-                    a, weights_marlinv1, lambda a, w_ref, w_q, w_s: ops.
-                    gptq_marlin_gemm(a,
-                                     w_q,
-                                     w_s,
-                                     w_zp_empty,
-                                     g_idx,
-                                     sort_indices,
-                                     workspace.scratch,
-                                     wtype,
-                                     size_m=a.shape[0],
-                                     size_n=w_ref.shape[1],
-                                     size_k=w_ref.shape[0],
-                                     is_k_full=True))))
+            bench_fns(
+                label, sub_label,
+                f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [
+                    cutlass_scaled_mm_create_bench_fn(bt)
+                    for bt in benchmark_tensors
+                ]))
+
+    if types.act_type != torch.float8_e4m3fn:
+        timers.append(
+            bench_fns(label, sub_label, f"marlin ({name_type_string})",
+                      [marlin_create_bench_fn(bt)
+                       for bt in benchmark_tensors]))

     # machete
     timers.append(
-        bench_fn(
-            label, sub_label, "machete_heuristic", lambda: loop_over_weights(
-                a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm(
-                    a, w_q, wtype, b_scales=w_s, b_group_size=group_size))))
+        bench_fns(label, sub_label, f"machete ({name_type_string})", [
+            machete_create_bench_fn(bt, out_type=types.output_type)
+            for bt in benchmark_tensors
+        ]))

     if sweep_schedules:
+        global _SWEEP_SCHEDULES_RESULTS
+
         print("Finding best schedule for machete")
         best = None
         best_schedule = None
-        schedules = ops.machete_supported_schedules(wtype)
+        schedules = ops.machete_supported_schedules(
+            a_type=types.act_type,
+            b_type=types.weight_type,
+            group_scales_type=types.group_scale_type,
+            group_zeros_type=types.group_zero_type,
+            token_scales_type=types.token_scale_type,
+            channel_scales_type=types.channel_scale_type,
+            out_type=types.output_type)
+
+        if schedules is None or len(schedules) == 0:
+            raise ValueError("No schedules found to sweep")
+
         for schedule in reversed(schedules):
             schedule_M = int(schedule.split("_")[0].split("x")[1])
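
Illustration (not part of the diff): machete_create_bench_fn above folds the group zero points into the scales before calling ops.machete_mm (w_g_zp = -1 * bt.w_g_s * w_g_zp). A minimal torch sketch of why that folding is equivalent to the usual (w_q - zp) * s group dequantization, using simplified per-element groups; all shapes and values here are made up for illustration.

    import torch

    g, n = 4, 8                                   # one scale/zero per (group, column)
    w_q = torch.randint(0, 16, (g, n)).float()    # unpacked 4-bit codes
    s = torch.rand(g, n) * 0.1                    # group scales
    zp = torch.randint(0, 16, (g, n)).float()     # group zero points

    reference = (w_q - zp) * s      # textbook group dequantization
    folded_zp = -1 * s * zp         # what the benchmark passes as b_group_zeros
    fused = w_q * s + folded_zp     # scale first, then add the folded zero point

    assert torch.allclose(reference, fused)
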
@@ -177,16 +380,11 @@ def bench(atype: torch.dtype,
             if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
                 continue

-            def run(a, _, w_q, w_s, schedule=schedule):
-                ops.machete_gemm(a,
-                                 w_q,
-                                 wtype,
-                                 w_s,
-                                 b_group_size=group_size,
-                                 schedule=schedule)
-
-            res = bench_fn(label, sub_label, "machete_best",
-                           lambda: loop_over_weights(a, weights_machete, run))
+            res = bench_fns(label, sub_label, "machete_best", [
+                machete_create_bench_fn(
+                    bt, out_type=types.output_type, schedule=schedule)
+                for bt in benchmark_tensors
+            ])

             results_row = {
                 "M": m,
@@ -213,25 +411,33 @@ def bench(atype: torch.dtype,


 # runner
-def print_timers(timers: Iterable[TMeasurement]):
+def print_timers(timers: List[TMeasurement]):
     compare = TBenchmark.Compare(timers)
     compare.print()


-def run(dtype: torch.dtype, sweep_schedules: bool,
-        MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
-    results = []
+def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
+    types = TypeConfig(
+        act_type=args.act_type,
+        weight_type=scalar_types.uint4b8 if args.group_zero_type is None \
+            else scalar_types.uint4,
+        output_type=args.out_type,
+        group_scale_type=args.group_scale_type,
+        group_zero_type=args.group_zero_type,
+        channel_scale_type=args.channel_scale_type,
+        token_scale_type=args.token_scale_type,
+    )
+
+    results: List[TMeasurement] = []
     for m, k, n in MKNs:
-        timers = bench(dtype,
-                       scalar_types.uint4b8,
-                       128,
+        timers = bench(types,
+                       args.group_size,
                        m,
                        k,
                        n,
-                       f"{dtype}-gemm",
+                       f"{args.act_type}-gemm",
                        f"MKN=({m}x{k}x{n})",
-                       sweep_schedules=sweep_schedules)
+                       sweep_schedules=args.sweep_schedules)
         print_timers(timers)
         results.extend(timers)

@@ -240,7 +446,7 @@ def run(dtype: torch.dtype, sweep_schedules: bool,

 # output makers
 def make_output(
-    data: Iterable[TMeasurement],
+    data: List[TMeasurement],
     MKNs: Iterable[Tuple[int, int, int]],
     base_description: str,
     timestamp=None,
@@ -262,7 +468,6 @@ def run_square_bench(args):
     dim_sizes = list(
         range(args.dim_start, args.dim_end + 1, args.dim_increment))
     MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
-
     data = run(args.dtype, args.sweep_schedules, MKNs)

     make_output(data, MKNs, f"square_bench-{args.dtype}")
@@ -306,33 +511,49 @@ def run_model_bench(args):
         for k, n in KNs:
             MKNs.append((m, k, n))

-        data = run(args.dtype, args.sweep_schedules, MKNs)
+        data = run(args, MKNs)
         model_bench_data.append(data)

+    type_string = f"{args.act_type}"
+
     # Print all results
     for data, model_tp in zip(model_bench_data, models_tps):
         model, tp_size = model_tp
-        print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
+        print(f"== Results {type_string} {model}-TP{tp_size} ====")
         print_timers(data)

-    timestamp = int(time.time())
+    timestr = time.strftime("%Y%m%d-%H%M%S")

-    all_data = []
+    all_results = []
     for d in model_bench_data:
-        all_data.extend(d)
+        all_results.extend(d)

     # pickle all data
-    with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
-        pkl.dump(all_data, f)
+    with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f:
+        args_dict = vars(args)
+        args_dict.pop("func")
+        pkl.dump({
+            "args": args_dict,
+            "results": all_results,
+        }, f)


 if __name__ == "__main__":

     def to_torch_dtype(dt):
-        if dt == "bfloat16":
-            return torch.bfloat16
-        if dt == "float16":
-            return torch.float16
-        raise ValueError("unsupported dtype")
+        return {
+            "bfloat16": torch.bfloat16,
+            "float16": torch.float16,
+            "int8": torch.int8,
+            "float8_e4m3fn": torch.float8_e4m3fn,
+            "int": torch.int,
+            "float": torch.float,
+        }[dt]
+
+    class ToTorchDtype(argparse.Action):
+
+        def __call__(self, parser, namespace, values, option_string=None):
+            setattr(namespace, self.dest, to_torch_dtype(values))

     parser = FlexibleArgumentParser(
         description="""
@@ -352,12 +573,42 @@ Benchmark Machete GEMM.
 """,  # noqa: E501
         formatter_class=argparse.RawTextHelpFormatter,
     )

     parser.add_argument(
-        "--dtype",
-        type=to_torch_dtype,
+        "--act-type",
+        action=ToTorchDtype,
         required=True,
-        help="Available options are ['bfloat16', 'float16']",
+        choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'],
+    )
+    parser.add_argument(
+        "--group-scale-type",
+        action=ToTorchDtype,
+        choices=['bfloat16', 'float16'],
+    )
+    parser.add_argument(
+        "--group-zero-type",
+        type=to_torch_dtype,
+        choices=['bfloat16', 'float16'],
+    )
+    parser.add_argument(
+        "--channel-scale-type",
+        action=ToTorchDtype,
+        choices=['float'],
+    )
+    parser.add_argument(
+        "--token-scale-type",
+        action=ToTorchDtype,
+        choices=['float'],
+    )
+    parser.add_argument(
+        "--out-type",
+        action=ToTorchDtype,
+        choices=['bfloat16', 'float16'],
+    )
+    parser.add_argument(
+        "--group-size",
+        type=int,
+        help="Available options are ['None', '-1', '128'], default=128",
+        default=128,
     )
     parser.add_argument(
         "--sweep-schedules",
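
Illustration (not part of the diff): the refactored entry point builds a TypeConfig from the new CLI flags and hands it to bench(). A hand-wired sketch of the same flow for one W4A8-style configuration, assuming the definitions from the benchmark script above (TypeConfig, bench, print_timers) are in scope and a CUDA build of vLLM is available; the exact M/K/N values are arbitrary.

    import torch
    from vllm.scalar_type import scalar_types

    types = TypeConfig(
        act_type=torch.int8,               # 8-bit activations
        weight_type=scalar_types.uint4b8,  # 4-bit weights, no group zero points
        output_type=torch.float16,
        group_scale_type=torch.float16,
        group_zero_type=None,
        channel_scale_type=torch.float,
        token_scale_type=torch.float,
    )

    timers = bench(types,
                   group_size=128,
                   m=16, k=4096, n=4096,
                   label="torch.int8-gemm",
                   sub_label="MKN=(16x4096x4096)",
                   sweep_schedules=False)
    print_timers(timers)
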
@@ -20,10 +20,11 @@ if __name__ == "__main__":
     args = parser.parse_args()

     with open(args.filename, 'rb') as f:
-        data: List[TMeasurement] = pickle.load(f)
+        data = pickle.load(f)
+        raw_results: List[TMeasurement] = data["results"]

     results = defaultdict(lambda: list())
-    for v in data:
+    for v in raw_results:
         result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
         if result is not None:
             KN = result.group(1)
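
Illustration (not part of the diff): the benchmark now pickles a dict with "args" and "results" instead of a bare list, which is what the plotting script above unpacks. A small sketch of reading that payload directly; the file name is a placeholder.

    import pickle

    with open("model_bench-torch.int8-20240101-000000.pkl", "rb") as f:
        data = pickle.load(f)

    print(data["args"])                   # the argparse namespace (minus "func") as a dict
    for measurement in data["results"]:   # List[TMeasurement]
        print(measurement.task_spec.sub_label, measurement.median)
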
@@ -40,4 +40,10 @@ WEIGHT_SHAPES = {
         ([8192, 57344], 1),
         ([28672, 8192], 0),
     ],
+    "meta-llama/Llama-3.1-405b-hf": [
+        ([16384, 18432], 1),
+        ([16384, 16384], 0),
+        ([16384, 106496], 1),
+        ([53248, 16384], 0),
+    ],
 }
@@ -20,9 +20,9 @@ CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) {
 // is the layout f(x) = x
 template <typename Layout>
 CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
-  if constexpr (std::is_same_v<Layout, void>)
+  if constexpr (std::is_same_v<Layout, void>) {
     return true;
-  else {
+  } else {
     constexpr auto coalesced_layout = coalesce(Layout{});
     if constexpr (rank(coalesced_layout) == 1 &&
                   stride<0>(coalesced_layout) == 1) {
@@ -52,6 +52,7 @@
 // clang-format off

 #include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
+#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
 #include "cute/tensor.hpp"

 namespace cutlass::epilogue::threadblock {
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp (new file)
@@ -0,0 +1,317 @@
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"

/*
   This file defines custom epilogues for fusing channel scales, token scales,
   bias, and activation zero-points onto a GEMM operation using the
   CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs.

   Epilogues must contain a public type named EVTCompute of type Sm80EVT,
   as well as a static prepare_args function that constructs an
   EVTCompute::Arguments struct.
*/

namespace vllm::c2x {

using namespace cute;

/*
 * This class provides the common load descriptors for the
 * ScaledEpilogue[...] classes
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
 protected:
  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

  template <typename T>
  using ColOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
      OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
      OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using RowOrZeroLoad =
      cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  // This utility function constructs the arguments for the load descriptors
  // from a tensor. It can handle both row and column, as well as row/column or
  // scalar cases.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(torch::Tensor const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
      return Arguments{data_ptr, tensor.numel() != 1};
    } else {
      // it would technically work but no use case as data_ptr is never nullptr
      static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
      return Arguments{data_ptr};
    }
  }

  // This overload handles the case where there might not be a tensor, in which
  // case a nullptr is passed and a constant (0) is used.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
    static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
    return Arguments{data_ptr};
  }
};

/*
   This epilogue function defines a quantized GEMM operation similar to
   torch._scaled_mm.

   A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
   per-row. B can be quantized per-tensor or per-column.
   Any combination of per-tensor and per-row or column is supported.
   A and B must have symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor or
   per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args};
  }
};

/*
 * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
 * This bias can also be used in the per-tensor azp case, where the activation
 * zero point (azp) is used to compute an azp correction term,
 * which is folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 protected:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;
  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
                                                             EVTCompute0, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args, bias_args};
  }
};

/*
 * This epilogue directly supports per-tensor azp in int32 form.
 * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
 * term, which should already be multiplied with the scalar azp.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // This is the full AZP term, azp * J @ B, shape (1,n)
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute float(accum - azp_adj), both operands are int32_t
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;

  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAzp>;

  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};

/*
 * This epilogue supports per-token azp by computing and applying
 * the correction term using a rank-1 update. If the term were materialized,
 * it would require O(m*n) space, and this way it only requires O(m+n) space.
 * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
 * point for each row of A.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // Per-token azp term, shape (m,1)
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // This is the AZP adjustment term, J @ B, shape (1,n)
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute azp * azp_adj
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;

  // Compute float(accum - azp*azp_adj), all operands are int32_t
  using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAcc =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;

  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAcc>;

  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};

};  // namespace vllm::c2x
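
Illustration (not part of the diff): the ScaledEpilogue comment above defines D = (a_scales * A) (b_scales * B) with numpy-style broadcasting. A plain-torch sketch of that semantics on the CPU, using int64 in place of the kernel's int32 accumulator and arbitrary small shapes.

    import torch

    m, k, n = 4, 8, 6
    a = torch.randint(-128, 128, (m, k), dtype=torch.int8)
    b = torch.randint(-128, 128, (k, n), dtype=torch.int8)
    a_scales = torch.rand(m, 1)   # per-row (per-token) scales, or a single scalar
    b_scales = torch.rand(1, n)   # per-column (per-channel) scales, or a single scalar

    accum = a.to(torch.int64) @ b.to(torch.int64)        # integer accumulator
    d = a_scales * b_scales * accum.to(torch.float32)    # epilogue: scale and convert

    reference = (a_scales * a.float()) @ (b_scales * b.float())
    assert torch.allclose(d, reference)
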
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp (new file)
@@ -0,0 +1,315 @@
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"

/*
   This file defines custom epilogues for fusing channel scales, token scales,
   bias, and activation zero-points onto a GEMM operation using the
   CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later.

   Epilogues must contain a public type named EVTCompute of type Sm90EVT,
   as well as a static prepare_args function that constructs an
   EVTCompute::Arguments struct.
*/

namespace vllm::c3x {

using namespace cute;

/*
 * This class provides the common load descriptors for the
 * ScaledEpilogue[...] classes
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBase {
 protected:
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  template <typename T>
  using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
      Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
      Stride<Int<0>, Int<1>, Int<0>>>;

  // Don't want to support nullptr by default
  template <typename T, bool EnableNullPtr = false>
  using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
      Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;

  // Don't want to support nullptr by default
  template <typename T, bool EnableNullPtr = false>
  using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
      Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;

  // This utility function constructs the arguments for the load descriptors
  // from a tensor. It can handle both row and column, as well as row/column or
  // scalar cases.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(torch::Tensor const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
      return Arguments{data_ptr, tensor.numel() != 1};
    } else {
      static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
                    !std::is_same_v<Descriptor, RowLoad<T, true>>);
      return Arguments{data_ptr};
    }
  }

  // This overload handles the case where there might not be a tensor, in which
  // case a nullptr is passed and a constant (0) is used.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
    static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
                  std::is_same_v<Descriptor, RowLoad<T, true>>);
    return Arguments{data_ptr};
  }
};

/*
   This epilogue function defines a quantized GEMM operation similar to
   torch.scaled_mm_.

   A and B may be both either int8 or fp8_e4m3. A can be
   quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
   Any combination of per-tensor and per-row or column is supported.
   A and B must have symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor or
   per row or column.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args};
  }
};

/*
 * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
 * This bias can also be used in the per-tensor azp case, where the activation
 * zero point (azp) is used to compute an azp correction term,
 * which is folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBias
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args, bias_args};
  }
};

/*
 * This epilogue directly supports per-tensor azp in int32 form.
 * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
 * term, which should already be multiplied with the scalar azp.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzp
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD, true>;

  // This is the full AZP term, azp * J @ B, shape (1,n)
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute float(accum - azp_adj), both operands are int32_t
  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;

  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;

  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
                                         EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};

/*
 * This epilogue supports per-token azp by computing and applying
 * the correction term using a rank-1 update. If the term were materialized,
 * it would require O(m*n) space, and this way it only requires O(m+n) space.
 * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
 * point for each row of A.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzpToken
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD, true>;

  // Per-token azp term, shape (m,1)
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // This is the AZP adjustment term, J @ B, shape (1,n)
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute azp * azp_adj
  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;

  // Compute float(accum - azp*azp_adj), all operands are int32_t
  using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAcc =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;

  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;

  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
                                         EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}; // namespace vllm::c3x
|
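As a quick cross-check of what the two AZP epilogues above compute, the following NumPy sketch reproduces the same math on the host. It is illustrative only: the array names (a_q, b_q, azp, and so on) are hypothetical and not part of the kernel API, and the per-tensor variant (ScaledEpilogueBiasAzp) simply expects azp_adj to have been pre-multiplied by the scalar azp.

import numpy as np

m, n, k = 4, 8, 16
rng = np.random.default_rng(0)
a_q = rng.integers(0, 256, (m, k)).astype(np.int32)      # asymmetric uint8-quantized A
b_q = rng.integers(-128, 128, (k, n)).astype(np.int32)   # symmetric int8-quantized B
azp = rng.integers(100, 150, (m, 1)).astype(np.int32)    # per-token zero point of A
a_s = rng.random((m, 1), dtype=np.float32)                # per-token scale of A
b_s = rng.random((1, n), dtype=np.float32)                # per-channel scale of B
bias = rng.random((1, n), dtype=np.float32)               # per-channel bias

accum = a_q @ b_q                          # int32 accumulator produced by the GEMM
azp_adj = np.ones((1, k), np.int32) @ b_q  # J @ B, shape (1, n)

# ScaledEpilogueBiasAzpToken: rank-1 azp correction, then scales and bias
d = a_s * (b_s * (accum - azp * azp_adj).astype(np.float32)) + bias

# Reference: dequantize first, then matmul
ref = (a_s * (a_q - azp)) @ (b_s * b_q) + bias
assert np.allclose(d, ref, rtol=1e-4)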
@ -35,6 +35,35 @@ VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
    }
}

VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = {
    **DataTypeSize,  # type: ignore
    **{
        VLLMDataType.u4b8: 4,
        VLLMDataType.u8b128: 8,
    }
}

VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
    VLLMDataType.u4b8: "vllm::kU4B8",
    VLLMDataType.u8b128: "vllm::kU8B128",
    DataType.u4: "vllm::kU4",
    DataType.u8: "vllm::kU8",
    DataType.s4: "vllm::kS4",
    DataType.s8: "vllm::kS8",
    DataType.f16: "vllm::kFloat16",
    DataType.bf16: "vllm::kBfloat16",
}

VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
    DataType.u8: "at::ScalarType::Byte",
    DataType.s8: "at::ScalarType::Char",
    DataType.e4m3: "at::ScalarType::Float8_e4m3fn",
    DataType.s32: "at::ScalarType::Int",
    DataType.f16: "at::ScalarType::Half",
    DataType.bf16: "at::ScalarType::BFloat16",
    DataType.f32: "at::ScalarType::Float",
}

VLLMKernelScheduleTag: Dict[Union[
    MixedInputKernelScheduleType, KernelScheduleType], str] = {
        **KernelScheduleTag,  # type: ignore
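These tables extend the upstream cutlass_library metadata so the kernel generators can render C++ tokens for vLLM's custom scalar types. As a rough illustration of how such maps get consumed (the emit_comment helper below is hypothetical and not part of this change):

from cutlass_library import DataType  # assumed import, matching the generator's environment

def emit_comment(dtype) -> str:
    # Render the vLLM ScalarType token and its bit-width for a given data type.
    return f"// {VLLMDataTypeVLLMScalarTypeTag[dtype]} ({VLLMDataTypeSize[dtype]} bits)"

print(emit_comment(DataType.u4))  # "// vllm::kU4 (4 bits)"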
@ -3,6 +3,7 @@
#include "cutlass/numeric_conversion.h"
#include "cutlass_extensions/vllm_custom_types.cuh"
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/vllm_type_utils.cuh"

// this file extends:
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
@ -28,8 +29,19 @@ struct InterleavedNumericArrayConverter {

  CUTLASS_DEVICE
  static result_type convert(source_type const& source) {
    CUTE_INVALID_CONTROL_PATH(
        "InterleavedNumericArrayConverter not implemented\n");
    if (cute::elect_one_sync()) {
      if constexpr (std::is_same_v<IlvBlkLayout, void>) {
        printf(
            "Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n",
            nameof_v<T>, nameof_v<S>, N);
      } else {
        printf(
            "Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not "
            "implemented\n",
            nameof_v<T>, nameof_v<S>, N, size(IlvBlkLayout{}));
      }
      __brkpt();
    }
    return {};
  }

@ -56,11 +68,6 @@ struct InterleavedNumericArrayConverter<
  result_type operator()(source_type const& s) const { return convert(s); }
};

// TODO (LucasWilkinson): Implement
// for Array<cutlass::float8_e4m3fn, N> <= Array<vllm_uint4b8_t, N>
// ....

template <typename RegConvert32bit, typename T, typename S, int N>
struct ArrayConverterPacked32Bit {
  using result_type = Array<T, N>;
@ -86,14 +93,16 @@ struct ArrayConverterPacked32Bit {
  using ScalarConverter = NumericConverter<T, S>;

  template <typename PackedSrc>
  CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) {
  CUTLASS_DEVICE static auto to_regs(PackedSrc const& src) {
    if constexpr (sizeof(PackedSrc) == 1) {
      return static_cast<uint32_t>(reinterpret_cast<const uint8_t&>(source));
      return Array<uint32_t, 1>{reinterpret_cast<uint8_t const&>(src)};
    } else if constexpr (sizeof(PackedSrc) == 2) {
      return static_cast<uint32_t>(reinterpret_cast<const uint16_t&>(source));
      return Array<uint32_t, 1>{reinterpret_cast<uint16_t const&>(src)};
    } else if constexpr (sizeof(PackedSrc) == 4) {
      return Array<uint32_t, 1>{reinterpret_cast<uint32_t const&>(src)};
    } else {
      static_assert(sizeof(PackedSrc) == 4);
      static_assert(sizeof(PackedSrc) == 8);
      return reinterpret_cast<const uint32_t&>(source);
      return reinterpret_cast<Array<uint32_t, 2> const&>(src);
    }
  }
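The to_reg → to_regs change above lets RegConvert implementations receive their packed source as one or two 32-bit registers instead of a single uint32_t, which is what allows the wider 8-byte sources used below. A host-side sketch of the same packing rule, assuming a little-endian layout (illustrative only):

def to_regs(packed: bytes) -> list[int]:
    # 1-, 2- and 4-byte sources become one 32-bit word (zero-extended);
    # 8-byte sources become two 32-bit words.
    assert len(packed) in (1, 2, 4, 8)
    words = []
    for off in range(0, max(len(packed), 4), 4):
        chunk = packed[off:off + 4]
        words.append(int.from_bytes(chunk.ljust(4, b"\x00"), "little"))
    return words

print([hex(w) for w in to_regs(bytes([0xAB]))])       # ['0xab']
print([hex(w) for w in to_regs(bytes(range(1, 9)))])  # ['0x4030201', '0x8070605']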
@ -110,7 +119,7 @@ struct ArrayConverterPacked32Bit {
    static_assert(std::is_same_v<typename PackedSrcType::Element, S>);
    static_assert(std::is_same_v<typename PackedResultType::Element, T>);

    return RegConvert32bit::template convert<PackedResultType>(to_reg(source));
    return RegConvert32bit::template convert<PackedResultType>(to_regs(source));
  }

  friend class detail::VectorizedConverter;
@ -140,6 +149,131 @@ struct ArrayConverterPacked32Bit {
  }
};

// Convert 8 4bit values packed into a 32bit register to 8 8bit values packed
// into 2 32bit register.
template <uint8_t LUT0, uint8_t LUT1, uint8_t LUT2, uint8_t LUT3,  //
          uint8_t LUT4, uint8_t LUT5, uint8_t LUT6, uint8_t LUT7,  //
          uint8_t LUT8, uint8_t LUT9, uint8_t LUT10, uint8_t LUT11,  //
          uint8_t LUT12, uint8_t LUT13, uint8_t LUT14, uint8_t LUT15>
CUTLASS_DEVICE cutlass::AlignedArray<uint32_t, 2> lut_4bit_to_8bit_convert(
    uint32_t src) {
  cutlass::AlignedArray<uint32_t, 2> r;
  // Determines if the value is in the top half of the LUT if set or
  // (i.e. LUT[8:15]) in the bottom half (i.e. LUT[0:7]) if not set. Then move
  // into bit position 0x4 of each nibble so when or'd with final_prmt_base it
  // selects the correct candidate. When elements in final_prmt_base
  // are >= 0x4, the high candidate is selected (i.e. LUT[8:15]), when elements
  // are < 0x4, the low candidate is selected (i.e. LUT[0:7])
  uint32_t high_bit = (src & 0x88888888) >> 1;

  // `high_bit` is OR'd with 0x31203120 to find the correct value in the LUT
  // (selects correct high or low candidate)
  const uint32_t final_prmt_base = 0x32103210;

  // Ignore the high bit when indexing into LUT, for each 4bit value
  // we index into both the high and low candidates then use
  // high_bit | final_prmt_base to select the correct candidate
  uint32_t lut_idx = (src & 0x77777777);

  auto pack = [](uint8_t a, uint8_t b, uint8_t c, uint8_t d) {
    return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) |
           (uint32_t(d) << 24);
  };

  static constexpr uint32_t LOW_0 = pack(LUT0, LUT1, LUT2, LUT3);
  static constexpr uint32_t LOW_1 = pack(LUT4, LUT5, LUT6, LUT7);
  static constexpr uint32_t HIGH_0 = pack(LUT8, LUT9, LUT10, LUT11);
  static constexpr uint32_t HIGH_1 = pack(LUT12, LUT13, LUT14, LUT15);

  CUTLASS_PRAGMA_UNROLL
  for (int ii = 0; ii < 2; ++ii, lut_idx >>= 16, high_bit >>= 16) {
    uint32_t final_prmt_idx = final_prmt_base | high_bit;

    // This uses a look up table to convert packed int4s to packed int8s,
    // using the int4 value as the index to prmt. It first select both the
    // high and low candidates, then uses the high bit (i.e. `high_bit`) to
    // select the correct candidate.
    asm volatile(
        "{\n"
        "  .reg .b32 low, high;\n"
        "  prmt.b32 low, %1, %2, %5;\n"
        "  prmt.b32 high, %3, %4, %5;\n"
        "  prmt.b32 %0, low, high, %6;\n"
        "}\n"
        : "=r"(r[ii])
        : "n"(LOW_0), "n"(LOW_1), "n"(HIGH_0), "n"(HIGH_1), "r"(lut_idx),
          "r"(final_prmt_idx));
  }

  return r;
};

// for Array<int8_t, N> <= Array<vllm_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<int8_t, vllm_uint4b8_t, N, Round> {
  using result_type = Array<int8_t, N>;
  using source_type = Array<vllm_uint4b8_t, N>;

  static FloatRoundStyle const round_style = Round;

 private:
  struct RegConvert {
    template <typename PackedResultType>
    CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
      // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as int8s
      auto r = lut_4bit_to_8bit_convert<0xF8, 0xF9, 0xFA, 0xFB,  //
                                        0xFC, 0xFD, 0xFE, 0xFF,  //
                                        0x00, 0x01, 0x02, 0x03,  //
                                        0x04, 0x05, 0x06, 0x07>(src_[0]);
      return reinterpret_cast<PackedResultType&>(r);
    };
  };

 public:
  CUTLASS_DEVICE
  static result_type convert(source_type const& source) {
    return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
                                     typename source_type::Element,
                                     N>::convert(source);
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const& s) const { return convert(s); }
};

// for Array<cutlass::float_e4m3_t, N> <= Array<vllm_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::float_e4m3_t, vllm_uint4b8_t, N, Round> {
  using result_type = Array<cutlass::float_e4m3_t, N>;
  using source_type = Array<vllm_uint4b8_t, N>;

  static FloatRoundStyle const round_style = Round;

 private:
  struct RegConvert {
    template <typename PackedResultType>
    CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
      // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as fp8s
      auto r = lut_4bit_to_8bit_convert<0xD0, 0xCE, 0xCC, 0xCA,  //
                                        0xC8, 0xC4, 0xC0, 0xB8,  //
                                        0x00, 0x38, 0x40, 0x44,  //
                                        0x48, 0x4A, 0x4C, 0x4E>(src_[0]);
      return reinterpret_cast<PackedResultType&>(r);
    };
  };

 public:
  CUTLASS_DEVICE
  static result_type convert(source_type const& source) {
    return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
                                     typename source_type::Element,
                                     N>::convert(source);
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const& s) const { return convert(s); }
};
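The 16 template parameters of lut_4bit_to_8bit_convert are simply the byte-level lookup table for the 16 possible 4-bit inputs; prmt indexes into both 8-byte halves of the table and the saved high bit picks the right candidate. A pure-Python model of the conversion, using the int8 table from the u4b8 converter above (illustrative only):

# LUT bytes used above for int8 <- u4b8 (4-bit unsigned with a bias of 8)
LUT = [0xF8, 0xF9, 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF,
       0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
# Reinterpreted as signed bytes this is exactly [-8, ..., 7]
assert [b - 256 if b >= 128 else b for b in LUT] == list(range(-8, 8))

def lut_4bit_to_8bit(src: int) -> int:
    """Convert 8 packed 4-bit values in a 32-bit word into 8 bytes
    (returned as a 64-bit int), one LUT lookup per nibble."""
    out = 0
    for i in range(8):
        nibble = (src >> (4 * i)) & 0xF
        out |= LUT[nibble] << (8 * i)
    return out

print(hex(lut_4bit_to_8bit(0x76543210)))  # 0xfffefdfcfbfaf9f8  -> -1 .. -8
print(hex(lut_4bit_to_8bit(0xFEDCBA98)))  # 0x706050403020100   ->  7 .. 0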
// for Array<cutlass::half_t, N> <= Array<vllm_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::half_t, vllm_uint4b8_t, N, Round> {
@ -148,7 +282,8 @@ struct NumericArrayConverter<cutlass::half_t, vllm_uint4b8_t, N, Round> {

  struct RegConvert {
    template <typename PackedResultType>
    CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
    CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
      uint32_t src = src_[0];
      using RegArray =
          cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
                                sizeof(PackedResultType)>;
@ -249,7 +384,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
 private:
  struct RegConvert {
    template <typename PackedResultType>
    CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
    CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
      uint32_t src = src_[0];
      using RegArray =
          cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
                                sizeof(PackedResultType)>;
@ -338,7 +474,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
 private:
  struct RegConvert {
    template <typename PackedResultType>
    CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
    CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
      uint32_t src = src_[0];
      using RegArray =
          cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
                                sizeof(PackedResultType)>;
@ -417,7 +554,8 @@ struct NumericArrayConverter<cutlass::half_t, vllm_uint8b128_t, N, Round> {

  struct RegConvert {
    template <typename PackedResultType>
    CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
    CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
      uint32_t src = src_[0];
      // Hold output FP16s in reg. We need 1 reg for every 2 elements
      using RegArray =
          cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
@ -469,7 +607,8 @@ struct NumericArrayConverter<float, vllm_uint8b128_t, N, Round> {
 private:
  struct RegConvert {
    template <typename PackedResultType>
    CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
    CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
      uint32_t src = src_[0];
      PackedResultType r;

      // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of
@ -513,7 +652,8 @@ struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint4b8_t, N, Round> {
 private:
  struct RegConvert {
    template <typename PackedResultType>
    CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) {
    CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
      uint32_t src_reg = src_[0];
      // Hold output BF16s in reg. We need 1 reg for every 2 elements
      using RegArray =
          cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
@ -603,7 +743,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
 private:
  struct RegConvert {
    template <typename PackedResultType>
    CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
    CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
      uint32_t src = src_[0];
      using RegArray =
          cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
                                sizeof(PackedResultType)>;
@ -671,7 +812,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
 private:
  struct RegConvert {
    template <typename PackedResultType>
    CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
    CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
      uint32_t src = src_[0];
      using RegArray =
          cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
                                sizeof(PackedResultType)>;
@ -788,6 +930,61 @@ struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint8b128_t, N, Round> {

#endif

// for Array<int8_t, N> <= Array<cutlass::half_t, N>
// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<int8_t, cutlass::half_t, N, Round> {
  using result_type = Array<int8_t, N>;
  using source_type = Array<cutlass::half_t, N>;

  struct RegConvert {
    // FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
    template <typename PackedResultType, int src_regs>
    CUTLASS_DEVICE static PackedResultType convert(
        Array<uint32_t, src_regs> src) {
      // Hold output int8s in reg. We need 1 reg for every 4 elements
      using RegArray = cutlass::AlignedArray<
          uint32_t, std::max(PackedResultType::kElements / 4, size_t(1))>;
      RegArray r;

      static constexpr uint32_t MAGIC_BIAS_ = 0x64806480;
      auto MAGIC_BIAS = *reinterpret_cast<const half2*>(&MAGIC_BIAS_);

      *reinterpret_cast<half2*>(&src[0]) =
          __hadd2(*reinterpret_cast<half2*>(&src[0]), MAGIC_BIAS);

      if constexpr (src_regs > 1) {
        *reinterpret_cast<half2*>(&src[1]) =
            __hadd2(*reinterpret_cast<half2*>(&src[1]), MAGIC_BIAS);
      }

      static_assert(PackedResultType::kElements <= 4);
      uint32_t uint8s;
      static constexpr uint32_t MASK_0246 = 0x6420;
      static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080;
      asm volatile("prmt.b32 %0,%1,%2,%3;\n"
                   : "=r"(uint8s)
                   : "r"(src[0]), "r"((src_regs > 1) ? src[1] : src[0]),
                     "n"(MASK_0246));

      uint32_t int8s = (uint8s ^ UINT8s_TO_INT8s_MASK);

      return reinterpret_cast<PackedResultType&>(int8s);
    };
  };

 public:
  CUTLASS_DEVICE
  static result_type convert(source_type const& source) {
    return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
                                     typename source_type::Element,
                                     N>::convert(source);
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const& s) const { return convert(s); }
};
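The converter above relies on the FastFP16toINT8 magic-bias trick: adding 1152.0 (fp16 bit pattern 0x6480) to an fp16 value in int8 range leaves round-to-nearest(x) + 128 in the low byte of the fp16 bit pattern, so prmt with MASK_0246 can gather those low bytes and the final XOR with 0x80808080 removes the +128 bias. A NumPy check of the scalar version (illustrative only; ties at .5 follow fp16 round-to-nearest-even, as __hadd2 does):

import numpy as np

def fast_fp16_to_int8(x: float) -> np.int8:
    biased = (np.float16(x) + np.float16(1152.0)).view(np.uint16)
    low_byte = np.uint8(int(biased) & 0xFF)         # what prmt.b32 with MASK_0246 extracts
    return np.uint8(low_byte ^ 0x80).view(np.int8)  # what the XOR with 0x80808080 does

for v in [-128.0, -1.5, -0.49, 0.0, 3.7, 126.5, 127.0]:
    assert fast_fp16_to_int8(v) == np.int8(np.rint(np.float16(v)))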
/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass

csrc/cutlass_extensions/vllm_type_utils.cuh (new file, 42 lines)
@ -0,0 +1,42 @@
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"
#include "cuda_bf16.h"

#include "cutlass_extensions/vllm_custom_types.cuh"

namespace cutlass {

template <typename T>
struct nameof {
  static constexpr char const* value = "unknown";
};

template <typename T>
inline constexpr auto nameof_v = nameof<T>::value;

#define NAMEOF_TYPE(T)                       \
  template <>                                \
  struct nameof<T> {                         \
    static constexpr char const* value = #T; \
  };

NAMEOF_TYPE(float_e4m3_t)
NAMEOF_TYPE(float_e5m2_t)
NAMEOF_TYPE(half_t)
NAMEOF_TYPE(nv_bfloat16)
NAMEOF_TYPE(bfloat16_t)
NAMEOF_TYPE(float)

NAMEOF_TYPE(int4b_t)
NAMEOF_TYPE(int8_t)
NAMEOF_TYPE(int32_t)
NAMEOF_TYPE(int64_t)

NAMEOF_TYPE(vllm_uint4b8_t)
NAMEOF_TYPE(uint4b_t)
NAMEOF_TYPE(uint8_t)
NAMEOF_TYPE(vllm_uint8b128_t)
NAMEOF_TYPE(uint32_t)
NAMEOF_TYPE(uint64_t)

};  // namespace cutlass
@ -8,6 +8,10 @@
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"

#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"

using namespace vllm;

/*
   This file defines quantized GEMM operations using the CUTLASS 2.x API, for
   NVIDIA GPUs with SM versions prior to sm90 (Hopper).
@ -22,12 +26,11 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
  TORCH_CHECK(b.dtype() == torch::kInt8);

  if (out.dtype() == torch::kBFloat16) {
    return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t,
                                            Epilogue>(
    return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
    return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}
@ -42,10 +45,10 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBias>(
    return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogue>(
    return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}
@ -61,10 +64,10 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (azp) {
    return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
    return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
  } else {
    return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzp>(
    return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}
@ -78,12 +81,11 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
  TORCH_CHECK(b.dtype() == torch::kInt8);

  if (out.dtype() == torch::kBFloat16) {
    return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t,
                                            Epilogue>(
    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}
@ -98,10 +100,10 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBias>(
    return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogue>(
    return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}
@ -117,10 +119,10 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (azp) {
    return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
    return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
  } else {
    return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzp>(
    return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}
@ -134,13 +136,12 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
  TORCH_CHECK(b.dtype() == torch::kInt8);

  if (out.dtype() == torch::kBFloat16) {
    return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
    return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                           Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    assert(out.dtype() == torch::kFloat16);
    return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t,
                                                 Epilogue>(
    return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
} else {
@ -148,13 +149,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  if (out.dtype() == torch::kBFloat16) {
    return vllm::cutlass_gemm_sm89_fp8_dispatch<
        cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>(
    return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
                                          cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return vllm::cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
    return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
                                          cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}
@ -170,10 +171,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBias>(
    return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogue>(
    return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}
@ -189,10 +190,10 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (azp) {
    return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
    return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
  } else {
    return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzp>(
    return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}
@ -21,7 +21,6 @@
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"

#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on

@ -71,307 +70,6 @@ struct enable_sm89_to_sm90 : Kernel {
#endif
  }
};

/*
 * This class provides the common load descriptors for the
 * ScaledEpilogue[...] classes
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
 protected:
  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

  template <typename T>
  using ColOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
      OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
      OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using RowOrZeroLoad =
      cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  // This utility function constructs the arguments for the load descriptors
  // from a tensor. It can handle both row and column, as well as row/column or
  // scalar cases.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(torch::Tensor const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
      return Arguments{data_ptr, tensor.numel() != 1};
    } else {
      // it would technically work but no use case as data_ptr is never nullptr
      static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
      return Arguments{data_ptr};
    }
  }

  // This overload handles the case where there might not be a tensor, in which
  // case a nullptr is passed and a constant (0) is used.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
    static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
    return Arguments{data_ptr};
  }
};

/*
   This epilogue function defines a quantized GEMM operation similar to
   torch._scaled_mm.

   A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
   per-row. B can be quantized per-tensor or per-column.
   Any combination of per-tensor and per-row or column is supported.
   A and B must have symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor or
   per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args};
  }
};

/*
 * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
 * This bias can also be used in the per-tensor azp case, where the activation
 * zero point (azp) is used to compute an azp correction term,
 * which is folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 protected:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;
  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
                                                             EVTCompute0, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args, bias_args};
  }
};

/*
 * This epilogue directly supports per-tensor azp in int32 form.
 * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
 * term, which should already be multiplied with the scalar azp.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // This is the full AZP term, azp * J @ B, shape (1,n)
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute float(accum - azp_adj), both operands are int32_t
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;

  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAzp>;

  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};

/*
 * This epilogue supports per-token azp by computing and applying
 * the correction term using a rank-1 update. If the term were materialized,
 * it would require O(m*n) space, and this way it only requires O(m+n) space.
 * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
 * point for each row of A.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // Per-token azp term, shape (m,1)
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // This is the AZP adjustment term, J @ B, shape (1,n)
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute azp * azp_adj
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;

  // Compute float(accum - azp*azp_adj), all operands are int32_t
  using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAcc =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;

  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAcc>;

  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};

template <typename Arch, template <typename> typename ArchGuard,
          typename ElementAB_, typename ElementD_,
          template <typename, typename> typename Epilogue_, typename TileShape,
@ -23,11 +23,12 @@
|
|||||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||||
|
|
||||||
#include "broadcast_load_epilogue_c3x.hpp"
|
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||||
#include "common.hpp"
|
#include "common.hpp"
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
|
using namespace vllm;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||||
@ -56,305 +57,6 @@ struct enable_sm90_or_later : Kernel {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/*
|
|
||||||
* This class provides the common load descriptors for the
|
|
||||||
* ScaledEpilogue[...] classes
|
|
||||||
*/
|
|
||||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
|
||||||
struct ScaledEpilogueBase {
|
|
||||||
protected:
|
|
||||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
|
|
||||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
|
||||||
Stride<Int<1>, Int<0>, Int<0>>>;
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
  using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
      Stride<Int<0>, Int<1>, Int<0>>>;

  // Don't want to support nullptr by default
  template <typename T, bool EnableNullPtr = false>
  using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
      Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;

  // Don't want to support nullptr by default
  template <typename T, bool EnableNullPtr = false>
  using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
      Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;

  // This utility function constructs the arguments for the load descriptors
  // from a tensor. It can handle both row and column, as well as row/column or
  // scalar cases.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(torch::Tensor const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
      return Arguments{data_ptr, tensor.numel() != 1};
    } else {
      static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
                    !std::is_same_v<Descriptor, RowLoad<T, true>>);
      return Arguments{data_ptr};
    }
  }

  // This overload handles the case where there might not be a tensor, in which
  // case a nullptr is passed and a constant (0) is used.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
    static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
                  std::is_same_v<Descriptor, RowLoad<T, true>>);
    return Arguments{data_ptr};
  }
};

/*
   This epilogue function defines a quantized GEMM operation similar to
   torch.scaled_mm_.

   A and B may be both either int8 or fp8_e4m3. A can be
   quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
   Any combination of per-tensor and per-row or column is supported.
   A and B must have symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor or
   per row or column.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args};
  }
};
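As a reference for the math this epilogue fuses, the short standalone PyTorch sketch below (an illustration only, not code from this commit; all tensor names are hypothetical) checks the identity D = (a_scales * A) (b_scales * B) with the same broadcasting rules: ScaleA broadcasts along columns (per-row or scalar), ScaleB along rows (per-column or scalar), and the scales are applied to the int32 accumulator.

import torch

m, n, k = 4, 8, 16
a = torch.randint(-128, 127, (m, k), dtype=torch.int8)
b = torch.randint(-128, 127, (k, n), dtype=torch.int8)
a_scales = torch.rand(m, 1) * 0.01  # per-row; a single scalar would be per-tensor
b_scales = torch.rand(1, n) * 0.01  # per-column; a single scalar would be per-tensor

# Accumulate in int32, then apply both scales elementwise with broadcasting,
# mirroring the ScaleB * Accum and ScaleA * (...) stages of the EVT above.
accum = (a.to(torch.int32) @ b.to(torch.int32)).to(torch.float32)
d = a_scales * (b_scales * accum)

ref = (a_scales * a.float()) @ (b_scales * b.float())
torch.testing.assert_close(d, ref)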

/*
 * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
 * This bias can also be used in the per-tensor azp case, where the activation
 * zero point (azp) is used to compute an azp correction term,
 * which is folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBias
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args, bias_args};
  }
};

/*
 * This epilogue directly supports per-tensor azp in int32 form.
 * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
 * term, which should already be multiplied with the scalar azp.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzp
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD, true>;

  // This is the full AZP term, azp * J @ B, shape (1,n)
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute float(accum - azp_adj), both operands are int32_t
  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;

  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;

  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
                                         EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};
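To make the per-tensor azp correction concrete, here is a small standalone PyTorch check (illustration only, not code from this commit; variable names are hypothetical) of the identity the comment above relies on: since J is the all-ones matrix, azp * J @ B is just azp times the per-column sums of B, and subtracting it from the raw accumulator is equivalent to running the GEMM on zero-centred activations.

import torch

m, n, k = 4, 8, 16
a_q = torch.randint(0, 255, (m, k), dtype=torch.int32)    # asymmetric activations
b_q = torch.randint(-128, 127, (k, n), dtype=torch.int32)
azp = 128                                                  # per-tensor zero point

# azp_adj = azp * J @ B, shape (1, n): azp times the column sums of B.
azp_adj = azp * b_q.sum(dim=0, keepdim=True)

corrected = a_q @ b_q - azp_adj          # what the epilogue computes from Accum
reference = (a_q - azp) @ b_q            # GEMM on zero-centred A
assert torch.equal(corrected, reference)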

/*
 * This epilogue supports per-token azp by computing and applying
 * the correction term using a rank-1 update. If the term were materialized,
 * it would require O(m*n) space, and this way it only requires O(m+n) space.
 * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
 * point for each row of A.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzpToken
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD, true>;

  // Per-token azp term, shape (m,1)
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // This is the AZP adjustment term, J @ B, shape (1,n)
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute azp * azp_adj
  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;

  // Compute float(accum - azp*azp_adj), all operands are int32_t
  using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAcc =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;

  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;

  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
                                         EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};
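The per-token variant applies the same correction as a rank-1 update: azp (m,1) times azp_adj = J @ B (1,n). The sketch below (illustrative PyTorch, not part of the commit; names are hypothetical) materializes the outer product only to verify the identity, whereas the epilogue applies it without ever storing the O(m*n) term, keeping the extra state at O(m+n).

import torch

m, n, k = 4, 8, 16
a_q = torch.randint(0, 255, (m, k), dtype=torch.int32)
b_q = torch.randint(-128, 127, (k, n), dtype=torch.int32)
azp = torch.randint(100, 156, (m, 1), dtype=torch.int32)  # per-token zero points

azp_adj = b_q.sum(dim=0, keepdim=True)  # J @ B, shape (1, n)
rank1 = azp * azp_adj                   # outer product, shape (m, n); the kernel never stores this

corrected = a_q @ b_q - rank1
reference = (a_q - azp) @ b_q
assert torch.equal(corrected, reference)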

template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,

@ -721,11 +423,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
  if (bias) {
    TORCH_CHECK(bias->dtype() == c.dtype(),
                "currently bias dtype must match output dtype ", c.dtype());
-   return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBias>(
+   return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
        c, a, b, a_scales, b_scales, *bias);
  } else {
-   return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogue>(c, a, b, a_scales,
-                                                          b_scales);
+   return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogue>(
+       c, a, b, a_scales, b_scales);
  }
}

@ -740,10 +442,10 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (azp) {
-   return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>(
+   return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzpToken>(
        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
  } else {
-   return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzp>(
+   return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}
@ -3,8 +3,10 @@ import math
import os
import shutil
from collections.abc import Iterable
-from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from copy import deepcopy
+from dataclasses import dataclass, fields
+from functools import reduce
+from typing import Dict, List, Optional, Tuple, Union

import jinja2
# yapf conflicts with isort for this block
@ -14,7 +16,10 @@ from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag,
                                            MixedInputKernelScheduleType,
                                            TileSchedulerTag,
                                            TileSchedulerType, VLLMDataType,
-                                            VLLMDataTypeNames, VLLMDataTypeTag,
+                                            VLLMDataTypeNames,
+                                            VLLMDataTypeSize, VLLMDataTypeTag,
+                                            VLLMDataTypeTorchDataTypeTag,
+                                            VLLMDataTypeVLLMScalarTypeTag,
                                            VLLMKernelScheduleTag)

# yapf: enable
@ -27,49 +32,125 @@ DISPATCH_TEMPLATE = """
#include "../machete_mm_launcher.cuh"

namespace machete {
-using GemmDispatcher_ = GemmDispatcher<
-    {{DataTypeTag[type_config.element_a]}},  // ElementA
-    {{DataTypeTag[type_config.element_b]}},  // ElementB
-    {{DataTypeTag[type_config.element_d]}},  // ElementD
-    {{DataTypeTag[type_config.accumulator]}}, // Accumulator
-    {{DataTypeTag[type_config.element_b_scale]}},  // Scales
-    {{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints
-
-{% for s in schedules %}extern torch::Tensor
-impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PyTorchArguments args);
-{% endfor %}
-template <>
-torch::Tensor GemmDispatcher_::dispatch(PyTorchArguments args) {
+{% for impl_config in impl_configs %}
+{% set type_sig = gen_type_sig(impl_config.types) -%}
+{% for s in impl_config.schedules %}
+extern torch::Tensor impl_{{type_sig}}_sch_{{gen_sch_sig(s)}}(MMArgs);
+{%- endfor %}
+
+torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) {
  [[maybe_unused]] auto M = args.A.size(0);
  [[maybe_unused]] auto N = args.B.size(1);
  [[maybe_unused]] auto K = args.A.size(1);

-  if (!args.schedule) {
-    {%- for cond, s in heuristic %}
+  if (!args.maybe_schedule) {
+    {%- for cond, s in impl_config.heuristic %}
    {%if cond is not none%}if ({{cond}})
    {%- else %}else
    {%- endif %}
-        return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);{% endfor %}
+        return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);{% endfor %}
  }

-  {% for s in schedules %}
-  if (*args.schedule == "{{ gen_sch_name(s) }}") {
-    return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);
-  }
-  {% endfor %}
+  {%- for s in impl_config.schedules %}
+  if (*args.maybe_schedule == "{{ gen_sch_sig(s) }}")
+    return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);
+  {%- endfor %}
  TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for "
-                                     "schedule = ", *args.schedule);
+                                     "schedule = ", *args.maybe_schedule);
+}
+{%- endfor %}
+
+static inline std::optional<at::ScalarType> maybe_scalartype(
+    c10::optional<at::Tensor> const& t) {
+  if (!t) {
+    return std::nullopt;
+  } else {
+    return t->scalar_type();
+  };
}

-template <>
-std::vector<std::string> GemmDispatcher_::supported_schedules() {
-  return {
-    {% for s in schedules -%}
-    "{{ gen_sch_name(s) }}"{{ ",
-    " if not loop.last }}{%- endfor %}
-  };
+torch::Tensor mm_dispatch(MMArgs args) {
+  auto out_type = args.maybe_out_type.value_or(args.A.scalar_type());
+  auto a_type = args.A.scalar_type();
+  auto maybe_g_scales_type = maybe_scalartype(args.maybe_group_scales);
+  auto maybe_g_zeros_type = maybe_scalartype(args.maybe_group_zeros);
+  auto maybe_ch_scales_type = maybe_scalartype(args.maybe_channel_scales);
+  auto maybe_tok_scales_type = maybe_scalartype(args.maybe_token_scales);
+
+  {% for impl_config in impl_configs %}
+  {% set t = impl_config.types -%}
+  {% set type_sig = gen_type_sig(t) -%}
+  if (args.b_type == {{VLLMScalarTypeTag[t.b]}}
+      && a_type == {{TorchTypeTag[t.a]}}
+      && out_type == {{TorchTypeTag[t.out]}}
+      && {%if t.b_group_scale != void -%}
+         maybe_g_scales_type == {{TorchTypeTag[t.b_group_scale]}}
+         {%- else %}!maybe_g_scales_type{%endif%}
+      && {%if t.b_group_zeropoint != void -%}
+         maybe_g_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}}
+         {%- else %}!maybe_g_zeros_type{%endif%}
+      && {%if t.b_channel_scale != void -%}
+         maybe_ch_scales_type == {{TorchTypeTag[t.b_channel_scale]}}
+         {%- else %}!maybe_ch_scales_type{%endif%}
+      && {%if t.a_token_scale != void -%}
+         maybe_tok_scales_type == {{TorchTypeTag[t.a_token_scale]}}
+         {%- else %}!maybe_tok_scales_type{%endif%}
+  ) {
+    return mm_dispatch_{{type_sig}}(args);
+  }
+  {%- endfor %}
+
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false, "machete_mm(..) is not implemented for "
+      "a_type=", args.A.scalar_type(),
+      ", b_type=", args.b_type.str(),
+      ", out_type=", out_type,
+      ", with_group_scale_type=", maybe_g_scales_type
+          ? toString(*maybe_g_scales_type) : "None",
+      ", with_group_zeropoint_type=", maybe_g_zeros_type
+          ? toString(*maybe_g_zeros_type) : "None",
+      ", with_channel_scale_type=", maybe_ch_scales_type
+          ? toString(*maybe_ch_scales_type) : "None",
+      ", with_token_scale_type=", maybe_tok_scales_type
+          ? toString(*maybe_tok_scales_type) : "None",
+      "; implemented types are: \\n",
+      {%- for impl_config in impl_configs %}
+      {% set t = impl_config.types -%}
+      "\\t{{gen_type_option_name(t)}}\\n",
+      {%- endfor %}
+      "");
}
+
+std::vector<std::string> supported_schedules_dispatch(
+    SupportedSchedulesArgs args) {
+  auto out_type = args.maybe_out_type.value_or(args.a_type);
+
+  {% for impl_config in impl_configs %}
+  {% set t = impl_config.types -%}
+  {% set schs = impl_config.schedules -%}
+  if (args.b_type == {{VLLMScalarTypeTag[t.b]}}
+      && args.a_type == {{TorchTypeTag[t.a]}}
+      && out_type == {{TorchTypeTag[t.out]}}
+      && {%if t.b_group_scale != void -%}
+         args.maybe_group_scales_type == {{TorchTypeTag[t.b_group_scale]}}
+         {%- else %}!args.maybe_group_scales_type{%endif%}
+      && {%if t.b_group_zeropoint != void-%}
+         args.maybe_group_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}}
+         {%- else %}!args.maybe_group_zeros_type{%endif%}
+  ) {
+    return {
+      {%- for s in impl_config.schedules %}
+      "{{gen_sch_sig(s)}}"{% if not loop.last %},{% endif %}
+      {%- endfor %}
+    };
+  }
+  {%- endfor %}
+
+  return {};
+};

};  // namespace machete
"""
@ -77,20 +158,10 @@ IMPL_TEMPLATE = """
#include "../machete_mm_launcher.cuh"

namespace machete {
-template <typename Config, bool with_C, bool with_scales, bool with_zeropoints>
-using Kernel = MacheteKernelTemplate<
-    {{DataTypeTag[type_config.element_a]}},  // ElementA
-    {{DataTypeTag[type_config.element_b]}},  // ElementB
-    {{DataTypeTag[type_config.element_d]}},  // ElementD
-    {{DataTypeTag[type_config.accumulator]}}, // Accumulator
-    {{DataTypeTag[type_config.element_b_scale]}},  // Scales
-    {{DataTypeTag[type_config.element_b_zeropoint]}}, // Zeropoints
-    cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
-    Config, with_C, with_scales, with_zeropoints>;
-
-{% for sch in schedules %}
-{% set schedule_name = gen_sch_name(sch) -%}
-struct sch_{{schedule_name}} {
+{% for sch in unique_schedules(impl_configs) %}
+{% set sch_sig = gen_sch_sig(sch) -%}
+struct sch_{{sch_sig}} {
  using TileShapeNM = Shape<{{
      to_cute_constant(sch.tile_shape_mn)|join(', ')}}>;
  using ClusterShape = Shape<{{
@ -101,27 +172,34 @@ struct sch_{{schedule_name}} {
  using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}};
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
};

-torch::Tensor
-impl_{{type_name}}_sch_{{schedule_name}}(PyTorchArguments args) {
-  bool with_C = args.C.has_value(), with_scales = args.scales.has_value(),
-       with_zeropoints = args.zeros.has_value();
-
-  {% for s in specializations %}
-  if (with_C == {{s.with_C|lower}}
-      && with_zeropoints == {{s.with_zeropoints|lower}}
-      && with_scales == {{s.with_scales|lower}}) {
-      return run_impl<Kernel<sch_{{schedule_name}}, {{s.with_C|lower}},
-          {{s.with_scales|lower}}, {{s.with_zeropoints|lower}}>>(args);
-  }{% endfor %}
-
-  TORCH_CHECK_NOT_IMPLEMENTED(
-      false, "for the sake of compile times and binary size machete_mm(..) is "
-      " not implemented for with_C=", with_C, ", with_scales=", with_scales,
-      ", with_zeropoints=", with_zeropoints,
-      " (for {{type_name}}_sch_{{schedule_name}})");
-}
{% endfor %}

+{% for impl_config in impl_configs %}
+{% set t = impl_config.types -%}
+{% set schs = impl_config.schedules -%}
+{% set type_sig = gen_type_sig(t) -%}
+
+template<typename Sch>
+using Kernel_{{type_sig}} = MacheteKernelTemplate<
+    {{DataTypeTag[t.a]}},  // ElementA
+    {{DataTypeTag[t.b]}},  // ElementB
+    {{DataTypeTag[t.out]}},  // ElementD
+    {{DataTypeTag[t.accumulator]}}, // Accumulator
+    {{DataTypeTag[t.b_group_scale]}},  // GroupScaleT
+    {{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
+    {{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
+    {{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
+    cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
+    Sch>;
+
+{% for sch in schs %}
+{% set sch_sig = gen_sch_sig(sch) -%}
+torch::Tensor
+impl_{{type_sig}}_sch_{{sch_sig}}(MMArgs args) {
+  return run_impl<Kernel_{{type_sig}}<sch_{{sch_sig}}>>(args);
+}
+{%- endfor %}
+{%- endfor %}

};  // namespace machete
"""
@ -130,26 +208,34 @@ PREPACK_TEMPLATE = """
#include "../machete_prepack_launcher.cuh"

namespace machete {
-using PrepackBDispatcher_ = PrepackBDispatcher<
-    {{DataTypeTag[type_config.element_a]}},  // ElementA
-    {{DataTypeTag[type_config.element_b]}},  // ElementB
-    {{DataTypeTag[type_config.element_d]}},  // ElementD
-    {{DataTypeTag[type_config.accumulator]}}, // Accumulator
-    {{DataTypeTag[type_config.element_b_scale]}},  // Scales
-    {{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints
-
-using PrepackedLayoutB = PrepackedLayoutBTemplate<
-    {{DataTypeTag[type_config.element_a]}},  // ElementA
-    {{DataTypeTag[type_config.element_b]}},  // ElementB
-    {{DataTypeTag[type_config.element_d]}},  // ElementD
-    {{DataTypeTag[type_config.accumulator]}}, // Accumulator
-    cutlass::layout::ColumnMajor,
-    cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>;
-
-template <>
-torch::Tensor PrepackBDispatcher_::dispatch(torch::Tensor B) {
-  return prepack_impl<PrepackedLayoutB>(B);
+torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
+  auto convert_type = args.maybe_group_scales_type.value_or(args.a_type);
+  {%- for t in types %}
+  {% set b_type = unsigned_type_with_bitwidth(t.b_num_bits) %}
+  if (args.a_type == {{TorchTypeTag[t.a]}}
+      && args.b_type.size_bits() == {{t.b_num_bits}}
+      && convert_type == {{TorchTypeTag[t.convert]}}) {
+    return prepack_impl<
+        PrepackedLayoutBTemplate<
+            {{DataTypeTag[t.a]}},  // ElementA
+            {{DataTypeTag[b_type]}}, // ElementB
+            {{DataTypeTag[t.convert]}}, // ElementConvert
+            {{DataTypeTag[t.accumulator]}}, // Accumulator
+            cutlass::layout::ColumnMajor,
+            cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>
+    >(args.B);
+  }
+  {%- endfor %}
+
+  TORCH_CHECK_NOT_IMPLEMENTED(false,
+      "prepack_B_dispatch(..) is not implemented for "
+      "atype = ", args.a_type,
+      ", b_type = ", args.b_type.str(),
+      ", with_group_scales_type= ", args.maybe_group_scales_type ?
+          toString(*args.maybe_group_scales_type) : "None");
}

};  // namespace machete
"""
@ -166,32 +252,34 @@ class ScheduleConfig:
    tile_scheduler: TileSchedulerType


-@dataclass
+@dataclass(frozen=True)
class TypeConfig:
-    element_a: DataType
-    element_b: Union[DataType, VLLMDataType]
-    element_b_scale: DataType
-    element_b_zeropoint: DataType
-    element_d: DataType
+    a: DataType
+    b: Union[DataType, VLLMDataType]
+    b_group_scale: DataType
+    b_group_zeropoint: DataType
+    b_channel_scale: DataType
+    a_token_scale: DataType
+    out: DataType
+    accumulator: DataType
+
+
+@dataclass(frozen=True)
+class PrepackTypeConfig:
+    a: DataType
+    b_num_bits: int
+    convert: DataType
    accumulator: DataType


-@dataclass
-class Specialization:
-    with_C: bool
-    with_zeropoints: bool
-    with_scales: bool
-
-
@dataclass
class ImplConfig:
-    type_config: TypeConfig
-    schedule_configs: List[ScheduleConfig]
-    specializations: List[Specialization]
+    types: TypeConfig
+    schedules: List[ScheduleConfig]
    heuristic: List[Tuple[Optional[str], ScheduleConfig]]


-def generate_schedule_name(schedule_config: ScheduleConfig) -> str:
+def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
    tile_shape = (
        f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
    )
@ -209,40 +297,34 @@ def generate_schedule_name(schedule_config: ScheduleConfig) -> str:
            f"_{epilogue_schedule}_{tile_scheduler}")


-# mostly unique shorter schedule_name
-def generate_terse_schedule_name(schedule_config: ScheduleConfig) -> str:
+# mostly unique shorter sch_sig
+def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
    kernel_terse_names_replace = {
        "KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_",
        "TmaWarpSpecializedCooperative_": "TmaCoop_",
        "StreamKScheduler": "streamK",
    }

-    schedule_name = generate_schedule_name(schedule_config)
+    sch_sig = generate_sch_sig(schedule_config)
    for orig, terse in kernel_terse_names_replace.items():
-        schedule_name = schedule_name.replace(orig, terse)
-    return schedule_name
+        sch_sig = sch_sig.replace(orig, terse)
+    return sch_sig


# unique type_name
-def generate_type_signature(kernel_type_config: TypeConfig):
-    element_a = VLLMDataTypeNames[kernel_type_config.element_a]
-    element_b = VLLMDataTypeNames[kernel_type_config.element_b]
-    element_d = VLLMDataTypeNames[kernel_type_config.element_d]
-    accumulator = VLLMDataTypeNames[kernel_type_config.accumulator]
-    element_scale = VLLMDataTypeNames[kernel_type_config.element_b_scale]
-    element_zeropoint = VLLMDataTypeNames[
-        kernel_type_config.element_b_zeropoint]
-
-    return (f"{element_a}{element_b}{element_d}"
-            f"{accumulator}{element_scale}{element_zeropoint}")
+def generate_type_signature(kernel_types: TypeConfig):
+    return str("".join([
+        VLLMDataTypeNames[getattr(kernel_types, field.name)]
+        for field in fields(TypeConfig)
+    ]))


-# non-unique shorter type_name
-def generate_terse_type_signature(kernel_type_config: TypeConfig):
-    element_a = VLLMDataTypeNames[kernel_type_config.element_a]
-    element_b = VLLMDataTypeNames[kernel_type_config.element_b]
-
-    return f"{element_a}{element_b}"
+def generate_type_option_name(kernel_types: TypeConfig):
+    return ", ".join([
+        f"{field.name.replace('b_', 'with_')+'_type'}=" +
+        VLLMDataTypeNames[getattr(kernel_types, field.name)]
+        for field in fields(TypeConfig)
+    ])


def is_power_of_two(n):
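The new generate_type_signature derives the signature from the dataclass itself instead of hand-listing fields. The toy script below (an illustration of the pattern only, with hypothetical names standing in for VLLMDataTypeNames and the real TypeConfig) shows the same dataclasses.fields-based join, so any field added to the config automatically becomes part of the signature.

from dataclasses import dataclass, fields

NAMES = {"f16": "f16", "u4b8": "u4b8", "void": "void", "f32": "f32"}  # toy name table

@dataclass(frozen=True)
class ToyTypeConfig:
    a: str
    b: str
    b_group_scale: str
    accumulator: str

def toy_type_sig(t: ToyTypeConfig) -> str:
    # Concatenate the name of every field in declaration order, so the
    # signature stays unique per type configuration.
    return "".join(NAMES[getattr(t, f.name)] for f in fields(ToyTypeConfig))

print(toy_type_sig(ToyTypeConfig("f16", "u4b8", "f16", "f32")))  # -> f16u4b8f16f32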
@ -263,13 +345,36 @@ def to_cute_constant(value: List[int]):
        return _to_cute_constant(value)


+def unique_schedules(impl_configs: List[ImplConfig]):
+    return list(
+        set(sch for impl_config in impl_configs
+            for sch in impl_config.schedules))
+
+
+def unsigned_type_with_bitwidth(num_bits):
+    return {
+        4: DataType.u4,
+        8: DataType.u8,
+        16: DataType.u16,
+        32: DataType.u32,
+        64: DataType.u64,
+    }[num_bits]
+
+
template_globals = {
+    "void": DataType.void,
    "DataTypeTag": VLLMDataTypeTag,
+    "VLLMScalarTypeTag": VLLMDataTypeVLLMScalarTypeTag,
+    "TorchTypeTag": VLLMDataTypeTorchDataTypeTag,
    "KernelScheduleTag": VLLMKernelScheduleTag,
    "EpilogueScheduleTag": EpilogueScheduleTag,
    "TileSchedulerTag": TileSchedulerTag,
    "to_cute_constant": to_cute_constant,
-    "gen_sch_name": generate_terse_schedule_name,
+    "gen_sch_sig": generate_terse_sch_sig,
+    "gen_type_sig": generate_type_signature,
+    "unique_schedules": unique_schedules,
+    "unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
+    "gen_type_option_name": generate_type_option_name
}
@ -284,42 +389,82 @@ mm_impl_template = create_template(IMPL_TEMPLATE)
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)


-def create_sources(impl_config: ImplConfig, num_impl_files=1):
+def create_sources(impl_configs: List[ImplConfig], num_impl_files=8):
    sources = []

-    type_name = generate_type_signature(impl_config.type_config)
-    terse_type_name = generate_terse_type_signature(impl_config.type_config)
-
    sources.append((
-        f"machete_mm_{terse_type_name}",
-        mm_dispatch_template.render(type_name=type_name,
-                                    type_config=impl_config.type_config,
-                                    schedules=impl_config.schedule_configs,
-                                    heuristic=impl_config.heuristic),
+        "machete_mm_dispatch",
+        mm_dispatch_template.render(impl_configs=impl_configs),
    ))

+    prepack_types = []
+    for impl_config in impl_configs:
+        convert_type = impl_config.types.a \
+            if impl_config.types.b_group_scale == DataType.void \
+            else impl_config.types.b_group_scale
+        prepack_types.append(
+            PrepackTypeConfig(
+                a=impl_config.types.a,
+                b_num_bits=VLLMDataTypeSize[impl_config.types.b],
+                convert=convert_type,
+                accumulator=impl_config.types.accumulator,
+            ))
+
+    def prepacked_type_key(prepack_type: PrepackTypeConfig):
+        # For now we can just use the first accumulator type seen since
+        # the tensor core shapes/layouts don't vary based on accumulator
+        # type so we can generate less code this way
+        return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert)
+
+    unique_prepack_types = []
+    prepack_types_seen = set()
+    for prepack_type in prepack_types:
+        key = prepacked_type_key(prepack_type)
+        if key not in prepack_types_seen:
+            unique_prepack_types.append(prepack_type)
+            prepack_types_seen.add(key)
+
    sources.append((
-        f"machete_prepack_{terse_type_name}",
-        prepack_dispatch_template.render(
-            type_name=type_name,
-            type_config=impl_config.type_config,
-        ),
+        "machete_prepack",
+        prepack_dispatch_template.render(types=unique_prepack_types, ),
    ))

-    num_schedules = len(impl_config.schedule_configs)
-    schedules_per_file = math.ceil(num_schedules / num_impl_files)
-    for part, i in enumerate(range(0, num_schedules, schedules_per_file)):
-        file_schedules = impl_config.schedule_configs[i:i + schedules_per_file]
-
+    # Split up impls across files
+    num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
+    num_impls_per_file = math.ceil(num_impls / num_impl_files)
+
+    files_impls: List[List[ImplConfig]] = [[]]
+
+    curr_num_impls_assigned = 0
+    curr_impl_in_file = 0
+    curr_impl_configs = deepcopy(list(reversed(impl_configs)))
+
+    while curr_num_impls_assigned < num_impls:
+        room_left_in_file = num_impls_per_file - curr_impl_in_file
+        if room_left_in_file == 0:
+            files_impls.append([])
+            room_left_in_file = num_impls_per_file
+            curr_impl_in_file = 0
+
+        curr_ic = curr_impl_configs[-1]
+        if len(curr_ic.schedules) >= room_left_in_file:
+            # Break apart the current impl config
+            tmp_ic = deepcopy(curr_ic)
+            tmp_ic.schedules = curr_ic.schedules[:room_left_in_file]
+            curr_ic.schedules = curr_ic.schedules[room_left_in_file:]
+            files_impls[-1].append(tmp_ic)
+        else:
+            files_impls[-1].append(curr_ic)
+            curr_impl_configs.pop()
+        curr_num_impls_assigned += len(files_impls[-1][-1].schedules)
+        curr_impl_in_file += len(files_impls[-1][-1].schedules)
+
+    for part, file_impls in enumerate(files_impls):
        sources.append((
-            f"machete_mm_{terse_type_name}_impl_part{part}",
-            mm_impl_template.render(
-                type_name=type_name,
-                type_config=impl_config.type_config,
-                schedules=file_schedules,
-                specializations=impl_config.specializations,
-            ),
+            f"machete_mm_impl_part{part+1}",
+            mm_impl_template.render(impl_configs=file_impls),
        ))

    return sources
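The new splitting loop in create_sources is a greedy packer: it fills each generated .cu file up to a per-file budget of (type, schedule) instantiations, breaking an ImplConfig apart when it does not fit in the space left. The toy function below (a simplified model with hypothetical names, not the commit's code) captures the same idea on plain lists, which makes the behavior easy to poke at interactively.

import math
from copy import deepcopy

def split_impls(impl_schedules, num_files):
    """Pack (name, schedules) pairs into num_files chunks, splitting an entry
    across chunks when the current chunk runs out of room."""
    total = sum(len(s) for _, s in impl_schedules)
    per_file = math.ceil(total / num_files)
    files, room = [[]], per_file
    for name, schedules in deepcopy(impl_schedules):
        while schedules:
            if room == 0:
                files.append([])
                room = per_file
            take, schedules = schedules[:room], schedules[room:]
            files[-1].append((name, take))
            room -= len(take)
    return files

# e.g. 5 GPTQ-like and 4 QQQ-like schedules packed into 3 files of 3 impls each
print(split_impls([("gptq", list(range(5))), ("qqq", list(range(4)))], 3))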
@ -328,187 +473,169 @@ def generate():
    # about how this works
    SCRIPT_DIR = os.path.dirname(__file__)

-    schedule_common_params = dict(
+    sch_common_params = dict(
        kernel_schedule=TmaMI,
        epilogue_schedule=TmaCoop,
        tile_scheduler=TileSchedulerType.StreamK,
    )

+    # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
+    default_tile_heuristic_config = {
+        #### M = 257+
+        "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
+        "M > 256": ((128, 256), (2, 1, 1)),
+        #### M = 129-256
+        "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
+        "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
+        "M > 128": ((128, 256), (2, 1, 1)),
+        #### M = 65-128
+        "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
+        "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
+        "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
+        "M > 64": ((128, 128), (2, 1, 1)),
+        #### M = 33-64
+        "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
+        "M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
+        "M > 32": ((128, 64), (2, 1, 1)),
+        #### M = 17-32
+        "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
+        "M > 16": ((256, 32), (2, 1, 1)),
+        #### M = 1-16
+        "N >= 26624": ((256, 16), (1, 1, 1)),
+        None: ((128, 16), (1, 1, 1)),
+    }
+
    # For now we use the same heuristic for all types
    # Heuristic is currently tuned for H100s
    default_heuristic = [
-        #### M = 257+
-        ("M > 256 && K <= 16384 && N <= 4096",
-         ScheduleConfig(tile_shape_mn=(128, 128), cluster_shape_mnk=(2, 1, 1),
-                        **schedule_common_params)),  # type: ignore
-        ("M > 256",
-         ScheduleConfig(tile_shape_mn=(128, 256), cluster_shape_mnk=(2, 1, 1),
-                        **schedule_common_params)),  # type: ignore
-        #### M = 129-256
-        ("M > 128 && K <= 4096 && N <= 4096",
-         ScheduleConfig(tile_shape_mn=(128, 64), cluster_shape_mnk=(2, 1, 1),
-                        **schedule_common_params)),  # type: ignore
-        ("M > 128 && K <= 8192 && N <= 8192",
-         ScheduleConfig(tile_shape_mn=(128, 128), cluster_shape_mnk=(2, 1, 1),
-                        **schedule_common_params)),  # type: ignore
-        ("M > 128",
-         ScheduleConfig(tile_shape_mn=(128, 256), cluster_shape_mnk=(2, 1, 1),
-                        **schedule_common_params)),  # type: ignore
-        #### M = 65-128
-        ("M > 64 && K <= 4069 && N <= 4069",
-         ScheduleConfig(tile_shape_mn=(128, 32), cluster_shape_mnk=(2, 1, 1),
-                        **schedule_common_params)),  # type: ignore
-        ("M > 64 && K <= 4069 && N <= 8192",
-         ScheduleConfig(tile_shape_mn=(128, 64), cluster_shape_mnk=(2, 1, 1),
-                        **schedule_common_params)),  # type: ignore
-        ("M > 64 && K >= 8192 && N >= 12288",
-         ScheduleConfig(tile_shape_mn=(256, 128), cluster_shape_mnk=(2, 1, 1),
-                        **schedule_common_params)),  # type: ignore
-        ("M > 64",
-         ScheduleConfig(tile_shape_mn=(128, 128), cluster_shape_mnk=(2, 1, 1),
-                        **schedule_common_params)),  # type: ignore
-        #### M = 33-64
-        ("M > 32 && K <= 6144 && N <= 6144",
-         ScheduleConfig(tile_shape_mn=(128, 16), cluster_shape_mnk=(1, 1, 1),
-                        **schedule_common_params)),  # type: ignore
-        ("M > 32 && K >= 16384 && N >= 12288",
-         ScheduleConfig(tile_shape_mn=(256, 64), cluster_shape_mnk=(2, 1, 1),
-                        **schedule_common_params)),  # type: ignore
-        ("M > 32",
-         ScheduleConfig(tile_shape_mn=(128, 64), cluster_shape_mnk=(2, 1, 1),
-                        **schedule_common_params)),  # type: ignore
-        #### M = 17-32
-        ("M > 16 && K <= 12288 && N <= 8192",
-         ScheduleConfig(tile_shape_mn=(128, 32), cluster_shape_mnk=(2, 1, 1),
-                        **schedule_common_params)),  # type: ignore
-        ("M > 16",
-         ScheduleConfig(tile_shape_mn=(256, 32), cluster_shape_mnk=(2, 1, 1),
-                        **schedule_common_params)),  # type: ignore
-        #### M = 1-16
-        ("N >= 26624",
-         ScheduleConfig(tile_shape_mn=(256, 16), cluster_shape_mnk=(1, 1, 1),
-                        **schedule_common_params)),  # type: ignore
-        (None,
-         ScheduleConfig(tile_shape_mn=(128, 16), cluster_shape_mnk=(1, 1, 1),
-                        **schedule_common_params)),  # type: ignore
+        (cond, ScheduleConfig(*tile_config,
+                              **sch_common_params))  # type: ignore
+        for cond, tile_config in default_tile_heuristic_config.items()
    ]

-    # Do not use schedules = list(set(...)) because we need to make sure
-    # the output list is deterministic; otherwise the generated kernel file
-    # will be non-deterministic and causes ccache miss.
-    schedules = []
-    for _, schedule_config in default_heuristic:
-        if schedule_config not in schedules:
-            schedules.append(schedule_config)
+    def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]):
+        # Do not use schedules = list(set(...)) because we need to make sure
+        # the output list is deterministic; otherwise the generated kernel file
+        # will be non-deterministic and causes ccache miss.
+        schedules = []
+        for _, schedule_config in heuristic:
+            if schedule_config not in schedules:
+                schedules.append(schedule_config)
+        return schedules

    impl_configs = []

    GPTQ_kernel_type_configs = list(
        TypeConfig(
-            element_a=element_a,
-            element_b=element_b,
-            element_b_scale=element_a,
-            element_b_zeropoint=element_a,
-            element_d=element_a,
+            a=a,
+            b=b,
+            b_group_scale=a,
+            b_group_zeropoint=DataType.void,
+            b_channel_scale=DataType.void,
+            a_token_scale=DataType.void,
+            out=a,
            accumulator=DataType.f32,
-        ) for element_b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
-        for element_a in (DataType.f16, DataType.bf16))
+        ) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
+        for a in (DataType.f16, DataType.bf16))

-    GPTQ_kernel_specializations = [
-        Specialization(with_C=False, with_zeropoints=False, with_scales=True)
-    ]
-
    impl_configs += [
-        ImplConfig(x[0], x[1], x[2], x[3])
-        for x in zip(GPTQ_kernel_type_configs, itertools.repeat(schedules),
-                     itertools.repeat(GPTQ_kernel_specializations),
+        ImplConfig(x[0], x[1], x[2])
+        for x in zip(GPTQ_kernel_type_configs,
+                     itertools.repeat(get_unique_schedules(default_heuristic)),
                     itertools.repeat(default_heuristic))
    ]

    AWQ_kernel_type_configs = list(
        TypeConfig(
-            element_a=element_a,
-            element_b=element_b,
-            element_b_scale=element_a,
-            element_b_zeropoint=element_a,
-            element_d=element_a,
+            a=a,
+            b=b,
+            b_group_scale=a,
+            b_group_zeropoint=a,
+            b_channel_scale=DataType.void,
+            a_token_scale=DataType.void,
+            out=a,
            accumulator=DataType.f32,
-        ) for element_b in (DataType.u4, DataType.u8)
-        for element_a in (DataType.f16, DataType.bf16))
+        ) for b in (DataType.u4, DataType.u8)
+        for a in (DataType.f16, DataType.bf16))

-    AWQ_kernel_specializations = [
-        Specialization(with_C=False, with_zeropoints=True, with_scales=True)
+    impl_configs += [
+        ImplConfig(x[0], x[1], x[2])
+        for x in zip(AWQ_kernel_type_configs,
+                     itertools.repeat(get_unique_schedules(default_heuristic)),
+                     itertools.repeat(default_heuristic))
+    ]
+
+    # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
+    # TODO (LucasWilkinson): Further tuning required
+    qqq_tile_heuristic_config = {
+        #### M = 257+
+        # ((128, 256), (2, 1, 1)) Broken for QQQ types
+        # TODO (LucasWilkinson): Investigate further
+        # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
+        # "M > 256": ((128, 256), (2, 1, 1)),
+        "M > 256": ((128, 128), (2, 1, 1)),
+        #### M = 129-256
+        "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
+        "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
+        # ((128, 256), (2, 1, 1)) Broken for QQQ types
+        # TODO (LucasWilkinson): Investigate further
+        # "M > 128": ((128, 256), (2, 1, 1)),
+        "M > 128": ((128, 128), (2, 1, 1)),
+        #### M = 65-128
+        "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
+        "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
+        "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
+        "M > 64": ((128, 128), (2, 1, 1)),
+        #### M = 33-64
+        "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
+        # Broken for QQQ types
+        # TODO (LucasWilkinson): Investigate further
+        #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
+        "M > 32": ((128, 64), (2, 1, 1)),
+        #### M = 17-32
+        "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
+        "M > 16": ((256, 32), (2, 1, 1)),
+        #### M = 1-16
+        "N >= 26624": ((256, 16), (1, 1, 1)),
+        None: ((128, 16), (1, 1, 1)),
+    }
+
+    # For now we use the same heuristic for all types
+    # Heuristic is currently tuned for H100s
+    qqq_heuristic = [
+        (cond, ScheduleConfig(*tile_config,
+                              **sch_common_params))  # type: ignore
+        for cond, tile_config in qqq_tile_heuristic_config.items()
+    ]
+
+    QQQ_kernel_types = [
+        *(TypeConfig(
+            a=DataType.s8,
+            b=VLLMDataType.u4b8,
+            b_group_scale=b_group_scale,
+            b_group_zeropoint=DataType.void,
+            b_channel_scale=DataType.f32,
+            a_token_scale=DataType.f32,
+            out=DataType.f16,
+            accumulator=DataType.s32,
+        ) for b_group_scale in (DataType.f16, DataType.void)),
+        *(TypeConfig(
+            a=DataType.e4m3,
+            b=VLLMDataType.u4b8,
+            b_group_scale=b_group_scale,
+            b_group_zeropoint=DataType.void,
+            b_channel_scale=DataType.f32,
+            a_token_scale=DataType.f32,
+            out=DataType.f16,
+            accumulator=DataType.f32,
+        ) for b_group_scale in (DataType.f16, DataType.void)),
    ]

    impl_configs += [
-        ImplConfig(x[0], x[1], x[2], x[3])
-        for x in zip(AWQ_kernel_type_configs, itertools.repeat(schedules),
-                     itertools.repeat(AWQ_kernel_specializations),
-                     itertools.repeat(default_heuristic))
+        ImplConfig(x[0], x[1], x[2])
+        for x in zip(QQQ_kernel_types,
+                     itertools.repeat(get_unique_schedules(qqq_heuristic)),
+                     itertools.repeat(qqq_heuristic))
    ]

    output_dir = os.path.join(SCRIPT_DIR, "generated")
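The heuristic tables above are plain condition-string dicts that the DISPATCH_TEMPLATE turns into a generated if/else chain, with the None key acting as the final fallback. The toy snippet below (illustrative only; the template text and names are simplified stand-ins, not the commit's) renders a miniature version of that chain with jinja2, which the generator already depends on.

import jinja2

toy_heuristic = {
    "M > 256": "sch_128x256",
    "M > 64": "sch_128x128",
    None: "sch_128x16",  # fallback, like the `{%if cond is not none%}` branch
}

TOY_TEMPLATE = """
{%- for cond, sch in heuristic.items() %}
{% if cond is not none %}if ({{cond}}){% else %}else{% endif %}
  return impl_{{sch}}(args);
{%- endfor %}
"""

print(jinja2.Template(TOY_TEMPLATE).render(heuristic=toy_heuristic))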
@ -521,12 +648,11 @@ def generate():
        os.makedirs(output_dir)

    # Render each group of configurations into separate files
-    for impl_config in impl_configs:
-        for filename, code in create_sources(impl_config):
-            filepath = os.path.join(output_dir, f"{filename}.cu")
-            with open(filepath, "w") as output_file:
-                output_file.write(code)
-            print(f"Rendered template to {filepath}")
+    for filename, code in create_sources(impl_configs):
+        filepath = os.path.join(output_dir, f"{filename}.cu")
+        with open(filepath, "w") as output_file:
+            output_file.write(code)
+        print(f"Rendered template to {filepath}")


if __name__ == "__main__":
@ -171,6 +171,10 @@ struct MacheteCollectiveMma {
|
|||||||
make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}),
|
make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}),
|
||||||
Int<DispatchPolicy::Stages>{})));
|
Int<DispatchPolicy::Stages>{})));
|
||||||
|
|
||||||
|
using SmemLayoutACopy = decltype(GmemLayoutA::TVbNbKL_to_offset_copy(
|
||||||
|
make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}),
|
||||||
|
Int<DispatchPolicy::Stages>{})));
|
||||||
|
|
||||||
using SmemLayoutAtomARowMajor =
|
using SmemLayoutAtomARowMajor =
|
||||||
decltype(rs_smem_selector<GmmaMajorA, ElementA,
|
decltype(rs_smem_selector<GmmaMajorA, ElementA,
|
||||||
decltype(cute::get<0>(TileShape_MNK{})),
|
decltype(cute::get<0>(TileShape_MNK{})),
|
||||||
@ -288,14 +292,7 @@ struct MacheteCollectiveMma {
|
|||||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0,
|
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0,
|
||||||
"SmemLayoutAtomScale must evenly divide tile k shape.");
|
"SmemLayoutAtomScale must evenly divide tile k shape.");
|
||||||
|
|
||||||
// Tile along modes in a way that maximizes the TMA box size.
|
// Tile along modes in a way that maximizes the TMA box size
|
||||||
using SmemLayoutACopy = decltype(tile_to_shape(
|
|
||||||
SmemLayoutAtomARowMajor{},
|
|
||||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}),
|
|
||||||
-      Int<DispatchPolicy::Stages>{}),
-      conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
-                    Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
 
   using SmemLayoutB = decltype(tile_to_shape(
       SmemLayoutAtomB{},
       make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
@@ -428,12 +425,12 @@ struct MacheteCollectiveMma {
   // clang-format on
 
   // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx)
-  using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset(
+  using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset_copy(
       make_shape(int32_t(0), int32_t(0), int32_t(0)))));
 
   using ATensor = decltype(make_tensor(
       get_logical_ptr(static_cast<InternalElementA const*>(nullptr)),
-      shape(GmemLayoutA::TVbNbKL_to_offset(
+      shape(GmemLayoutA::TVbNbKL_to_offset_copy(
           make_shape(int32_t(0), int32_t(0), int32_t(0)))),
       PrepackedStrideA{}));
 
@@ -450,8 +447,8 @@ struct MacheteCollectiveMma {
 
   static constexpr auto make_tma_copy_A(ATensor tensor_a = ATensor{}) {
     return make_tma_copy<TmaElementA>(
-        GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}),
-        shape(SmemLayoutA{}(_, _, cute::Int<0>{})),
+        GmemTiledCopyA{}, tensor_a, SmemLayoutACopy{}(_, _, cute::Int<0>{}),
+        shape(SmemLayoutACopy{}(_, _, cute::Int<0>{})),
         size<1>(ClusterShape{}));  // mcast along N mode for this M load, if any
   }
 
@@ -584,7 +581,7 @@ struct MacheteCollectiveMma {
     typename Params::TMA_Scale tma_load_scale;
     typename Params::TMA_Zero tma_load_zero;
 
-    auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L));
+    auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L));
     tma_load_a = make_tma_copy_A(
         make_logical_tensor(ptr_A, shape(layout), stride(layout)));
 
@@ -722,7 +719,7 @@ struct MacheteCollectiveMma {
     // (TILE_V,TILE_B,m,k,l)
     auto make_gA_mkl = [&]() {
       // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx)
-      auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L));
+      auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L));
      Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(layout));
      return local_tile(mA_mkl,
                        make_shape(size<0>(layout), PPBlocksPerTile_MK{}),
@@ -21,6 +21,8 @@
 
 #include "cutlass_extensions/cute_utils.cuh"
 #include "cutlass_extensions/vllm_numeric_conversion.cuh"
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
+#include "cutlass_extensions/torch_utils.hpp"
 #include "machete_collective_builder.cuh"
 #include "machete_prepacked_layout.cuh"
 #include "machete_interleaving_utils.cuh"
@@ -37,27 +39,42 @@ using namespace cute;
 // W is quantized, in this situation or right-hand operand is quantized so
 // we compute the transpose to move it to the left-hand side.
 template <typename ElementA_, typename ElementB_, typename ElementD_,
-          typename AccumulatorT, typename ScaleT, typename ZeroT,
-          class KernelSchedule, typename ScheduleConfig, bool with_C,
-          bool with_scales, bool with_zeropoints>
+          typename AccumulatorT, typename GroupScaleT, typename GroupZeroT,
+          typename ChannelScaleT, typename TokenScaleT, class KernelSchedule,
+          typename ScheduleConfig>
 struct MacheteKernelTemplate {
+  static constexpr bool with_C = false;  // not ever used
+  static constexpr bool with_group_scales = !std::is_same_v<GroupScaleT, void>;
+  static constexpr bool with_group_zeropoints =
+      !std::is_same_v<GroupZeroT, void>;
+  static constexpr bool with_channel_scales =
+      !std::is_same_v<ChannelScaleT, void>;
+  static constexpr bool with_token_scales = !std::is_same_v<TokenScaleT, void>;
+
   using MmaType = ElementA_;
   using ElementA = ElementA_;
   using ElementB = ElementB_;
   using ElementD = ElementD_;
   using ElementC = cute::conditional_t<with_C, ElementD, void>;
-  using ElementZ = ZeroT;
-  using ElementS = ScaleT;
-
-  using ElementAccumulator =
-      AccumulatorT;  // Element type for internal accumulation
+  using ElementAccumulator = AccumulatorT;
   using ElementCompute = AccumulatorT;  // For Epilogue
+  // Use dummy values when we don't have scales or zeropoints
+  using ElementZGroup =
+      cute::conditional_t<with_group_zeropoints, GroupZeroT, MmaType>;
+  using ElementSGroup =
+      cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
+  using ElementConvertGroup =
+      cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
+  using ElementSChannel =
+      cute::conditional_t<with_channel_scales, ChannelScaleT, AccumulatorT>;
+  using ElementSToken =
+      cute::conditional_t<with_token_scales, TokenScaleT, AccumulatorT>;
 
   using BTypeTuple = cute::conditional_t<
-      with_scales,
-      cute::conditional_t<with_zeropoints,
-                          cute::tuple<ElementB, ElementS, ElementZ>,
-                          cute::tuple<ElementB, ElementS>>,
+      with_group_scales,
+      cute::conditional_t<with_group_zeropoints,
+                          cute::tuple<ElementB, ElementSGroup, ElementZGroup>,
+                          cute::tuple<ElementB, ElementSGroup>>,
       ElementB>;
 
   using LayoutA = cutlass::layout::RowMajor;
@@ -71,8 +88,8 @@ struct MacheteKernelTemplate {
   using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
   using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
   using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
-  using StrideS = cutlass::detail::TagToStrideA_t<LayoutScale>;
-  using StrideZ = StrideS;
+  using StrideSGroup = cutlass::detail::TagToStrideA_t<LayoutScale>;
+  using StrideZGroup = StrideSGroup;
 
   using LayoutA_Transpose =
       typename cutlass::layout::LayoutTranspose<LayoutA>::type;
@@ -85,8 +102,8 @@ struct MacheteKernelTemplate {
   using OperatorClass = cutlass::arch::OpClassTensorOp;
 
   using PrepackedLayoutB =
-      PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementD_, AccumulatorT,
-                               LayoutA_Transpose, KernelSchedule>;
+      PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementConvertGroup,
+                               AccumulatorT, LayoutA_Transpose, KernelSchedule>;
 
   static int constexpr TileShapeK =
       128 * 8 / cutlass::sizeof_bits<MmaType>::value;
@@ -103,12 +120,42 @@ struct MacheteKernelTemplate {
   using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
   using TileScheduler = typename ScheduleConfig::TileScheduler;
 
+  static_assert(
+      (!with_channel_scales && !with_token_scales) ||
+          ((with_channel_scales && with_token_scales) &&
+           std::is_same_v<ElementSChannel, ElementSToken>),
+      "Currently token and channel scales (if present) must be the same type");
+
+  using EpilogueDescriptor =
+      cutlass::epilogue::collective::detail::EpilogueDescriptor<
+          TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
+          ElementD, EpilogueSchedule>;
+
+  // Currently only supports float scales
+  using ChTokScalesEpilogue =
+      typename vllm::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
+                                         EpilogueDescriptor>;
+  static_assert((with_channel_scales || with_token_scales) ||
+                    (std::is_same_v<ElementSChannel, float> &&
+                     std::is_same_v<ElementSToken, float>),
+                "Currently token and channel scales (if present) must be float "
+                "(and if one is present the other must be too)");
+
+  using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
+      cutlass::epilogue::fusion::Sm90AccFetch>;
+
+  using EVTCompute =
+      std::conditional_t<with_channel_scales || with_token_scales,
+                         typename ChTokScalesEpilogue::EVTCompute,
+                         StoreEpilogueCompute>;
+
+  // EVTCompute
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
           ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
-          ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose,
-          AlignmentC, ElementD, LayoutD_Transpose, AlignmentD,
-          EpilogueSchedule>::CollectiveOp;
+          ElementAccumulator, ElementSChannel, ElementC, LayoutC_Transpose,
+          AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, EpilogueSchedule,
+          EVTCompute>::CollectiveOp;
 
   using CollectiveMainloop =
       typename cutlass::gemm::collective::VLLMCollectiveBuilder<
@@ -131,26 +178,44 @@ struct MacheteKernelTemplate {
   using MainloopArguments = typename GemmKernel::MainloopArguments;
   using EpilogueArguments = typename GemmKernel::EpilogueArguments;
 
-  template <typename ShapeA, typename ShapeC, typename ShapeD, typename ShapeS,
-            typename ShapeZ>
   static Arguments create_arguments(
       cudaStream_t stream,
-      ElementA const* A_ptr,  // A is an MxK matrix
-      Layout<ShapeA, StrideA> const& layout_A,
-      ElementB const* B_ptr,  // B is an KxN prepacked matrix
-      ElementD* D_ptr,  // D is an MxN matrix
-      Layout<ShapeD, StrideD> const& layout_D,
-      ElementC const* C_ptr,  // C is an MxN matrix
-      std::optional<Layout<ShapeC, StrideC>> const& layout_C,
-      ElementS const* S_ptr,  // S is an scale_KxN matrix
-      std::optional<Layout<ShapeS, StrideS>> const& layout_S,
-      ElementZ const* Z_ptr,  // Z is an scale_KxN matrix
-      std::optional<Layout<ShapeZ, StrideZ>> const& layout_Z,
-      ElementCompute alpha, ElementCompute beta,
-      std::optional<int> maybe_group_size) {
-    static_assert(!with_zeropoints || with_scales);
+      torch::Tensor const& A,  // MxK matrix
+      torch::Tensor const& B,  // KxN prepacked matrix
+      torch::Tensor& D,        // MxN matrix
+      c10::optional<torch::Tensor> const& maybe_g_scales,  // scale_KxN matrix
+      c10::optional<torch::Tensor> const& maybe_g_zeros,   // scale_KxN matrix
+      c10::optional<int64_t> maybe_group_size,
+      c10::optional<torch::Tensor> const& maybe_ch_scales,   // len N vector
+      c10::optional<torch::Tensor> const& maybe_tok_scales)  // len M vector
+  {
+    static_assert(!with_group_zeropoints || with_group_scales);
 
-    int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);
+    int M = A.size(0), N = B.size(1), K = A.size(1);
+    TORCH_CHECK(D.size(0) == M && D.size(1) == N);
+
+    auto layout_A = make_cute_layout<StrideA>(A, "A");
+    auto layout_D = make_cute_layout<StrideD>(D, "D");
+    auto layout_S_group =
+        maybe_make_cute_layout<StrideSGroup>(maybe_g_scales, "group_scales");
+    auto layout_Z_group =
+        maybe_make_cute_layout<StrideZGroup>(maybe_g_zeros, "group_zeros");
+    int64_t numel_S_channel = maybe_ch_scales ? maybe_ch_scales->numel() : 0;
+    int64_t numel_S_token = maybe_tok_scales ? maybe_tok_scales->numel() : 0;
+
+    auto unwrap = [](auto const& t) {
+      return t ? t->const_data_ptr() : nullptr;
+    };
+    auto A_ptr = static_cast<ElementA const*>(A.const_data_ptr());
+    auto B_ptr = static_cast<ElementB const*>(B.const_data_ptr());
+    auto D_ptr = static_cast<ElementD*>(D.mutable_data_ptr());
+    auto S_group_ptr =
+        static_cast<ElementSGroup const*>(unwrap(maybe_g_scales));
+    auto Z_group_ptr = static_cast<ElementZGroup const*>(unwrap(maybe_g_zeros));
+    auto S_channel_ptr =
+        static_cast<ElementSChannel const*>(unwrap(maybe_ch_scales));
+    auto S_token_ptr =
+        static_cast<ElementSToken const*>(unwrap(maybe_tok_scales));
 
     int const group_size =
         maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
@@ -159,26 +224,28 @@ struct MacheteKernelTemplate {
     TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
     TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);
 
-    if constexpr (with_C) {
-      TORCH_CHECK(C_ptr && layout_C);
+    if constexpr (with_group_scales) {
+      TORCH_CHECK(S_group_ptr && layout_S_group);
+      TORCH_CHECK((size<0>(*layout_S_group) == scale_k &&
+                   size<1>(*layout_S_group) == N));
     } else {
-      TORCH_CHECK(!C_ptr, "C not supported");
+      TORCH_CHECK(!S_group_ptr, "Scales not supported");
     }
 
-    if constexpr (with_scales) {
-      TORCH_CHECK(S_ptr && layout_S);
-      TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N));
-    } else {
-      TORCH_CHECK(!S_ptr, "Scales not supported");
-    }
-
-    if constexpr (with_zeropoints) {
-      TORCH_CHECK(Z_ptr && layout_Z);
-      TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N));
-      TORCH_CHECK(layout_S && *layout_Z == *layout_S,
+    if constexpr (with_group_zeropoints) {
+      TORCH_CHECK(Z_group_ptr && layout_Z_group);
+      TORCH_CHECK((size<0>(*layout_Z_group) == scale_k &&
+                   size<1>(*layout_Z_group) == N));
+      TORCH_CHECK(layout_S_group && *layout_Z_group == *layout_S_group,
                   "Scales and zeros must have the same layout");
     } else {
-      TORCH_CHECK(!Z_ptr, "Zeropoints not supported");
+      TORCH_CHECK(!Z_group_ptr, "Zeropoints not supported");
+    }
+
+    if constexpr (with_channel_scales || with_token_scales) {
+      TORCH_CHECK(
+          (maybe_ch_scales->numel() == N || maybe_ch_scales->numel() == 1) &&
+          (maybe_tok_scales->numel() == M || maybe_tok_scales->numel() == 1));
     }
 
     // Transpose A and D
@@ -186,24 +253,33 @@ struct MacheteKernelTemplate {
     // for B (which is At)
     auto stride_At = layout_A.stride();
     auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
-    auto stride_Ct = stride_Dt;
-    if (layout_C) {
-      stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride();
-    }
 
     MainloopArguments mainloop_arguments{};
-    EpilogueArguments epilogue_arguments{
-        {alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt};
+    // {Accum, C, C_layout, D, D}
+    EpilogueArguments epilogue_arguments{};
 
-    if constexpr (with_scales && with_zeropoints) {
-      auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
-      mainloop_arguments =
-          MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
-                            S_ptr, stride_S, group_size, Z_ptr};
-    } else if constexpr (with_scales) {
-      auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
+    if constexpr (with_channel_scales || with_token_scales) {
+      epilogue_arguments =
+          EpilogueArguments{ChTokScalesEpilogue::prepare_args(
+                                *maybe_ch_scales, *maybe_tok_scales),
+                            nullptr,
+                            {},
+                            D_ptr,
+                            stride_Dt};
+    } else {
+      epilogue_arguments = EpilogueArguments{{}, nullptr, {}, D_ptr, stride_Dt};
+    }
+
+    if constexpr (with_group_scales && with_group_zeropoints) {
+      auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
       mainloop_arguments = MainloopArguments{
-          B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size};
+          B_ptr, _StrideB{}, A_ptr, stride_At,
+          S_group_ptr, stride_S_group, group_size, Z_group_ptr};
+    } else if constexpr (with_group_scales) {
+      auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
+      mainloop_arguments =
+          MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
+                            S_group_ptr, stride_S_group, group_size};
     } else {
       mainloop_arguments =
          MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
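A minimal sketch of the shape contract that the new `create_arguments` enforces via its `TORCH_CHECK`s (A is MxK, the prepacked B is logically KxN, group scales/zeros are scale_k x N, channel scales are length N, token scales are length M). This is illustrative only, assumes K is divisible by the group size, and the helper name is hypothetical:

```python
from typing import Optional
import torch


def check_machete_mm_shapes(A: torch.Tensor, B: torch.Tensor, D: torch.Tensor,
                            group_scales: Optional[torch.Tensor],
                            group_zeros: Optional[torch.Tensor],
                            channel_scales: Optional[torch.Tensor],
                            token_scales: Optional[torch.Tensor],
                            group_size: Optional[int]) -> None:
    M, K = A.shape
    N = B.shape[1]
    assert D.shape == (M, N)
    gs = K if group_size in (None, -1) else group_size
    scale_k = K // gs  # assumes K divisible by the group size
    if group_scales is not None:
        assert group_scales.shape == (scale_k, N)
    if group_zeros is not None:
        # zeropoints require group scales with an identical layout
        assert group_scales is not None
        assert group_zeros.shape == group_scales.shape
    if channel_scales is not None or token_scales is not None:
        # if one is present the other must be too
        assert channel_scales is not None and token_scales is not None
        assert channel_scales.numel() in (N, 1)
        assert token_scales.numel() in (M, 1)
```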
@@ -5,73 +5,61 @@
 
 #include "machete_mm_kernel.cuh"
 #include "cutlass_extensions/torch_utils.hpp"
+#include "core/scalar_type.hpp"
 
 namespace machete {
 
-struct PyTorchArguments {
+struct MMArgs {
   torch::Tensor const& A;
   torch::Tensor const& B;
-  c10::optional<torch::Tensor> const& scales;
-  c10::optional<torch::Tensor> const& zeros;
-  c10::optional<int64_t> group_size;
-  c10::optional<torch::Tensor> const& C;
-  c10::optional<double> alpha;
-  c10::optional<double> beta;
-  c10::optional<std::string> schedule;
+  vllm::ScalarType const& b_type;
+  c10::optional<at::ScalarType> const& maybe_out_type;
+  c10::optional<torch::Tensor> const& maybe_group_scales;
+  c10::optional<torch::Tensor> const& maybe_group_zeros;
+  c10::optional<int64_t> maybe_group_size;
+  c10::optional<torch::Tensor> const& maybe_channel_scales;
+  c10::optional<torch::Tensor> const& maybe_token_scales;
+  c10::optional<std::string> maybe_schedule;
 };
 
+struct SupportedSchedulesArgs {
+  at::ScalarType a_type;
+  vllm::ScalarType b_type;
+  c10::optional<at::ScalarType> maybe_group_scales_type;
+  c10::optional<at::ScalarType> maybe_group_zeros_type;
+  c10::optional<at::ScalarType> maybe_channel_scales_type;
+  c10::optional<at::ScalarType> maybe_token_scales_type;
+  c10::optional<at::ScalarType> maybe_out_type;
+};
+
+torch::Tensor mm_dispatch(MMArgs args);
+
+std::vector<std::string> supported_schedules_dispatch(
+    SupportedSchedulesArgs args);
+
 template <typename MacheteKernel>
-torch::Tensor run_impl(PyTorchArguments args) {
+torch::Tensor run_impl(MMArgs args) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A));
 
   auto device = args.A.device();
   auto stream = at::cuda::getCurrentCUDAStream(device.index());
 
-  using EleA = typename MacheteKernel::ElementA;
-  using EleB = typename MacheteKernel::ElementB;
-  using EleC = typename MacheteKernel::ElementC;
-  using EleD = typename MacheteKernel::ElementD;
-  using EleScale = typename MacheteKernel::ElementS;
-  using EleZero = typename MacheteKernel::ElementZ;
-
-  using StrideA = typename MacheteKernel::StrideA;
-  using StrideC = typename MacheteKernel::StrideC;
-  using StrideD = typename MacheteKernel::StrideD;
-  using StrideS = typename MacheteKernel::StrideS;
-  using StrideZ = typename MacheteKernel::StrideZ;
-
   int M = args.A.size(0);
   int N = args.B.size(1);
   int K = args.A.size(1);
 
   // Allocate output
-  torch::Tensor D =
-      torch::empty({M, N}, torch::TensorOptions()
-                               .dtype(equivalent_scalar_type_v<EleD>)
-                               .device(device));
-
-  auto const &A = args.A, &B = args.B;
-  auto const &C = args.C, &scales = args.scales, &zeros = args.zeros;
-
-  auto layout_A = make_cute_layout<StrideA>(A, "A");
-  auto layout_D = make_cute_layout<StrideD>(D, "D");
-  auto layout_C = maybe_make_cute_layout<StrideC>(C, "C");
-  auto layout_S = maybe_make_cute_layout<StrideS>(scales, "scales");
-  auto layout_Z = maybe_make_cute_layout<StrideZ>(zeros, "zeros");
-
-  auto A_ptr = static_cast<EleA const*>(A.const_data_ptr());
-  auto B_ptr = static_cast<EleB const*>(B.const_data_ptr());
-  auto D_ptr = static_cast<EleD*>(D.mutable_data_ptr());
-  auto C_ptr = static_cast<EleC const*>(C ? C->const_data_ptr() : nullptr);
-  auto S_ptr =
-      static_cast<EleScale const*>(scales ? scales->const_data_ptr() : nullptr);
-  auto Z_ptr =
-      static_cast<EleZero const*>(zeros ? zeros->const_data_ptr() : nullptr);
+  torch::Tensor D = torch::empty(
+      {M, N},
+      torch::TensorOptions()
+          .dtype(equivalent_scalar_type_v<typename MacheteKernel::ElementD>)
+          .device(device));
 
   auto arguments = MacheteKernel::create_arguments(
-      stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr,
-      layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0),
-      args.group_size);
+      stream,  //
+      args.A, args.B, D, args.maybe_group_scales, args.maybe_group_zeros,
+      args.maybe_group_size, args.maybe_channel_scales,
+      args.maybe_token_scales);
   TORCH_CHECK(MacheteKernel::can_implement(arguments),
               "Machete kernel cannot be run with these arguments");
 
@@ -84,12 +72,4 @@ torch::Tensor run_impl(PyTorchArguments args) {
   return D;
 };
 
-template <typename ElementA, typename ElementB, typename ElementD = ElementA,
-          typename AccumulatorT = float, typename ScaleT = ElementA,
-          typename ZeroT = ElementA>
-struct GemmDispatcher {
-  static torch::Tensor dispatch(PyTorchArguments args);
-  static std::vector<std::string> supported_schedules();
-};
-
 };  // namespace machete
@@ -6,31 +6,49 @@
 
 namespace machete {
 
-template <typename TileShapeNKL, typename ElementB, typename BInTensor,
-          typename BTiledOutTensor>
-static __global__ void prepack_B_kernel(BInTensor B_in,
-                                        BTiledOutTensor B_tiled_out) {
-  auto tB_in = local_tile(B_in, TileShapeNKL{},
-                          make_coord(blockIdx.x, blockIdx.y, blockIdx.z));
-  auto tB_out = B_tiled_out(make_coord(_, _),
-                            make_coord(blockIdx.x, blockIdx.y), blockIdx.z);
+template <int threads, typename PrepackedLayoutB, typename BInTensor,
+          typename ElementB>
+static __global__ void prepack_B_kernel(BInTensor B_in, ElementB* B_out_ptr) {
+  auto constexpr block_size =
+      Int<size(typename PrepackedLayoutB::PPBlockShape_NK{})>{};
+  auto constexpr eles_per_thread = Int<block_size / threads>{};
+  static_assert(block_size % threads == 0,
+                "block_size must be divisible by the number of threads");
 
-  auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, ElementB>{},
-                                    Layout<Shape<_4, _32>, Stride<_32, _1>>{},
-                                    Layout<Shape<_1, _2>>{});
+  // Which pre-packed are we responsible for
+  auto blk_coord = make_coord(blockIdx.x, blockIdx.y, blockIdx.z);
+  auto tB_in = local_tile(
+      B_in, append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}),
+      blk_coord);
 
-  auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
+  // Find the start offset in the output for this pre-packed block
+  auto bNbKL_to_offset = PrepackedLayoutB::bNbKL_to_offset(shape(B_in));
 
-  Tensor thr_tile_S = thr_copy.partition_S(tB_in);
-  Tensor thr_tile_D = thr_copy.partition_D(tB_out);
+  // Tensor representing a 1:1 mapping to the output space in 1D
+  auto tB_out_linear =
+      make_tensor(get_logical_ptr(B_out_ptr) + bNbKL_to_offset(blk_coord),
+                  make_layout(make_shape(block_size)));
+  // Mapping from output space (1D) to input space
+  auto tB_in_linear = make_tensor(
+      tB_in.data(),
+      tB_in.layout()
+          .compose(right_inverse(PrepackedLayoutB::ppblock_ilvd_NK_to_offset()))
+          .with_shape(make_shape(block_size)));
+
+  // Tile for this specific thread (could have used a TiledCopy but these work
+  // best with 2d layouts, this is a simple 1d layout so local_tile is enough,
+  // we are also not that concerned with performance for this kernel)
+  auto thr_tB_in_linear =
+      local_tile(tB_in_linear, make_shape(eles_per_thread), threadIdx.x);
+  auto thr_tB_out_linear =
+      local_tile(tB_out_linear, make_shape(eles_per_thread), threadIdx.x);
 
   // Construct a register-backed Tensor with the same shape as each thread's
   // partition
-  auto fragment = make_tensor<ElementB>(shape(thr_tile_D));
+  auto fragment = make_tensor<ElementB>(shape(thr_tB_in_linear));
 
-  // Copy from GMEM to RMEM and from RMEM to GMEM
-  copy(tiled_copy, thr_tile_S, fragment);
-  copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tile_D);
+  copy(thr_tB_in_linear, fragment);
+  copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tB_out_linear);
 }
 
 template <typename PrepackedLayoutB, typename InLayout>
@@ -44,18 +62,15 @@ static void prepack_B_template(
 
   TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0);
   TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0);
-  TORCH_CHECK(size<2>(B_layout) % size<2>(TileShapeNKL{}) == 0);
 
   auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{});
   auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{});
-  auto L_tiles = size<2>(B_layout) / size<2>(TileShapeNKL{});
+  auto L_tiles = size<2>(B_layout);
 
   auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout);
-  auto B_tiled_out =
-      make_tensor(get_logical_ptr(B_out_ptr), ilvd_NKbNbKL_to_offset);
 
-  prepack_B_kernel<TileShapeNKL, typename PrepackedLayoutB::ElementB>
-      <<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_tiled_out);
+  prepack_B_kernel<128, PrepackedLayoutB>
+      <<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_out_ptr);
 }
 
 };  // namespace machete
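The rewritten kernel above copies one pre-packed block per thread block: it computes the block's linear start offset in the output and then reads the input block through the inverse of the in-block (interleaved) layout. A tiny NumPy sketch of that idea, with a hypothetical `perm` standing in for the real CUTE layout composition (it is not the actual interleaving used by machete):

```python
import numpy as np


def prepack_block(b_block: np.ndarray, out_perm: np.ndarray) -> np.ndarray:
    """Return the block's elements in prepacked (linear) order.

    out_perm[i] is the flat input index that lands in output slot i --
    analogous to composing the block layout with
    right_inverse(ppblock_ilvd_NK_to_offset()) in the kernel.
    """
    return b_block.reshape(-1)[out_perm]


# Example: a 4x8 block with a column-major permutation as a stand-in for the
# real interleaved block layout.
block = np.arange(32, dtype=np.uint8).reshape(4, 8)
perm = np.arange(32).reshape(4, 8).T.reshape(-1)
packed = prepack_block(block, perm)
assert packed[1] == block[1, 0]
```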
@@ -2,9 +2,17 @@
 
 #include "machete_prepack_kernel.cuh"
 #include "cutlass_extensions/torch_utils.hpp"
+#include "core/scalar_type.hpp"
 
 namespace machete {
 
+struct PrepackBArgs {
+  torch::Tensor const& B;
+  at::ScalarType a_type;
+  vllm::ScalarType b_type;
+  c10::optional<at::ScalarType> maybe_group_scales_type;
+};
+
 template <typename PrepackedLayoutB>
 torch::Tensor prepack_impl(torch::Tensor const B) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(B));
@@ -61,11 +69,6 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
   return D;
 };
 
-template <typename ElementA, typename ElementB, typename ElementD,
-          typename AccumulatorT = float, typename ScaleT = cutlass::half_t,
-          typename ZeroT = cutlass::half_t>
-struct PrepackBDispatcher {
-  static torch::Tensor dispatch(torch::Tensor B);
-};
+torch::Tensor prepack_B_dispatch(PrepackBArgs args);
 
 };  // namespace machete
@@ -41,7 +41,7 @@ struct IlvBlkLayoutAuto {};
 // The contract here is that the `TiledMma` determined below matches the one
 // ultimately used in the kernel. (this is also why the other element types are
 // required along with the kernel schedule)
-template <typename ElementA_, typename ElementB_, typename ElementD_,
+template <typename ElementA_, typename ElementB_, typename ElementConvert_,
           typename AccumulatorT, class LayoutB, class KernelSchedule,
           typename IlvBlkLayout_ = IlvBlkLayoutAuto>
 // clang-format on
@@ -49,20 +49,27 @@ struct PrepackedLayoutBTemplate {
   using MmaType = ElementA_;
   using ElementA = ElementA_;
   using ElementB = ElementB_;
-  using ElementD = ElementD_;
-  using ElementAccumulator =
-      AccumulatorT;  // Element type for internal accumulation
+  using ElementAccumulator = AccumulatorT;
   using ElementMma = MmaType;
 
-  // Only use interleaved layouts for subbyte weights, prmt instructions makes
-  // non-interleaved layouts for 8bit+ weights efficient enough we don't need
-  // iterleaved layouts
+  // Interleave for 4bit bit types when we are not upconverting to fp8 or int8,
+  // in those cases case we use a LUT using prmt instructions to upconvert and
+  // is more efficient if the data is not interleaved For 8bit+ prmt
+  // instructions makes non-interleaved layouts efficient enough we don't need
+  // iterleaved layouts (and can reuse more of the existing cutlass converts)
+  static constexpr bool should_interleave =
+      sizeof_bits_v<ElementB> <= 4 &&
+      !std::is_same_v<ElementConvert_, cutlass::float_e4m3_t> &&
+      !std::is_same_v<ElementConvert_, int8_t>;
+
+  // Only use interleaved layouts for subbyte weights,
   using IlvdBlkLayout = std::conditional_t<
       std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
-      std::conditional_t<sizeof_bits_v<ElementB> <= 4,
-                         decltype(get_interleaved_blk_layout<
-                                  ElementB, sizeof_bits_v<ElementA>, 32>()),
-                         void>,
+      std::conditional_t<
+          should_interleave,
+          decltype(get_interleaved_blk_layout<
+                   ElementB, sizeof_bits_v<ElementConvert_>, 32>()),
+          void>,
       IlvBlkLayout_>;
 
   // TODO (LucasWilkinson): compare the performance for other sizes
@@ -135,7 +142,8 @@ struct PrepackedLayoutBTemplate {
   // then ((IlvBlk), FrgB) is {A, C, B, D, C, G, D, H}
   auto frgV = get<1, 0>(layout_no_interleave);
   auto ilvdBlk = IlvdBlkLayout{};
-  static_assert(size(frgV) % 4 == 0, "FrgV must be divisible by 4");
+  static_assert(size(frgV) % size(ilvdBlk) == 0,
+                "FrgV must be divisible by size(ilvdBlk)");
   auto ilvd_FrgV = make_layout(
       make_shape(shape(ilvdBlk), Int<size(frgV) / size(ilvdBlk)>{}),
       make_stride(stride(ilvdBlk), size(ilvdBlk)));
@@ -175,6 +183,15 @@ struct PrepackedLayoutBTemplate {
     return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
   }
 
+  // ((athrid_val), (BlocksN, BlocksK, L)) -> (N, K, L)
+  template <typename Shape_NKL>
+  CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset_copy(
+      Shape_NKL shape_mkl) {
+    auto layout = TVbNbKL_to_offset(shape_mkl);
+    return make_layout(coalesce(get<0>(layout)), get<1>(layout),
+                       get<2>(layout));
+  }
+
   // ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
   template <typename Shape_NKL>
   CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset(
@@ -197,6 +214,19 @@ struct PrepackedLayoutBTemplate {
     return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
   }
 
+  // (BlocksN, BlocksK, L) -> (storage_idx)
+  template <typename Shape_NKL>
+  CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) {
+    // (BlocksN, BlocksK, L)
+    auto blocks_shape =
+        cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
+                        [](auto x, auto y) { return x / y; });
+    auto stride = size(PPBlockShape_NK{});
+
+    // (BlocksN, BlocksK, L) -> (storage_idx)
+    return make_layout(blocks_shape, compact_col_major(blocks_shape, stride));
+  }
+
   // ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
   template <class Shape_NKL>
   CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) {
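For reference, the arithmetic that the new `bNbKL_to_offset` helper encodes: each block index maps to a storage offset that advances by one pre-packed block, column-major over (BlocksN, BlocksK, L). A minimal sketch, where `pp_block_size` stands in for `size(PPBlockShape_NK{})`:

```python
def bNbKL_to_offset(bn: int, bk: int, l: int,
                    blocks_n: int, blocks_k: int, pp_block_size: int) -> int:
    # column-major block id over (BlocksN, BlocksK, L), scaled by block size
    block_id = bn + blocks_n * (bk + blocks_k * l)
    return block_id * pp_block_size


# e.g. with 8192-element pre-packed blocks, block (1, 0, 0) starts right
# after the first block:
assert bNbKL_to_offset(1, 0, 0, blocks_n=4, blocks_k=2,
                       pp_block_size=8192) == 8192
```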
@@ -8,89 +8,61 @@ namespace machete {
 
 using namespace vllm;
 
-//
-// Utils (type dispatching)
-//
-
-template <typename Fn>
-static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
-  if (type == vllm::kU4) {
-    return fn(cutlass::uint4b_t{});
-  } else if (type == vllm::kU8) {
-    return fn(cutlass::uint8_t{});
-  } else if (type == vllm::kU4B8) {
-    return fn(cutlass::vllm_uint4b8_t{});
-  } else if (type == vllm::kU8B128) {
-    return fn(cutlass::vllm_uint8b128_t{});
-  } else {
-    TORCH_CHECK(false, "Unsupported type ", type.str());
-  }
+std::vector<std::string> supported_schedules(
+    at::ScalarType a_type, int64_t b_type_id,
+    c10::optional<at::ScalarType> maybe_group_scales_type,
+    c10::optional<at::ScalarType> maybe_group_zeros_type,
+    c10::optional<at::ScalarType> maybe_channel_scales_type,
+    c10::optional<at::ScalarType> maybe_token_scales_type,
+    c10::optional<at::ScalarType> maybe_out_type) {
+  ScalarType const b_type = ScalarType::from_id(b_type_id);
+  return supported_schedules_dispatch({
+      .a_type = a_type,
+      .b_type = b_type,
+      .maybe_group_scales_type = maybe_group_scales_type,
+      .maybe_group_zeros_type = maybe_group_zeros_type,
+      .maybe_channel_scales_type = maybe_channel_scales_type,
+      .maybe_token_scales_type = maybe_token_scales_type,
+      .maybe_out_type = maybe_out_type,
+  });
 }
 
-#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \
-  AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__)
-
-#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_SWITCH(TYPE, NAME,                             \
-                     AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__))
-
-//
-// Interface
-//
-
-std::vector<std::string> supported_schedules(ScalarTypeId const btype_id) {
-#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
-  vllm::ScalarType b_type = ScalarType::from_id(btype_id);
-  return scalar_type_dispatch(b_type, [&](auto BType) {
-    return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
-  });
-#else
-  TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
-#endif
+torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
+                 int64_t b_type_id,
+                 c10::optional<at::ScalarType> const& maybe_out_type,
+                 c10::optional<torch::Tensor> const& maybe_group_scales,
+                 c10::optional<torch::Tensor> const& maybe_group_zeros,
+                 c10::optional<int64_t> maybe_group_size,
+                 c10::optional<torch::Tensor> const& maybe_channel_scales,
+                 c10::optional<torch::Tensor> const& maybe_token_scales,
+                 c10::optional<std::string> maybe_schedule) {
+  ScalarType const b_type = ScalarType::from_id(b_type_id);
+  return mm_dispatch({.A = A,
+                      .B = B,
+                      .b_type = b_type,
+                      .maybe_out_type = maybe_out_type,
+                      .maybe_group_scales = maybe_group_scales,
+                      .maybe_group_zeros = maybe_group_zeros,
+                      .maybe_group_size = maybe_group_size,
+                      .maybe_channel_scales = maybe_channel_scales,
+                      .maybe_token_scales = maybe_token_scales,
+                      .maybe_schedule = maybe_schedule});
 }
 
-torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
-                   ScalarTypeId const btype_id,
-                   c10::optional<torch::Tensor> const& scales,
-                   c10::optional<torch::Tensor> const& zeros,
-                   c10::optional<int64_t> group_size,
-                   c10::optional<torch::Tensor> const& C,
-                   c10::optional<double> alpha, c10::optional<double> beta,
-                   c10::optional<std::string> schedule) {
-#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
-  ScalarType const btype = ScalarType::from_id(btype_id);
-  auto args = PyTorchArguments{.A = A,
-                               .B = B,
-                               .scales = scales,
-                               .zeros = zeros,
-                               .group_size = group_size,
-                               .C = C,
-                               .alpha = alpha,
-                               .beta = beta,
-                               .schedule = schedule};
-
-  return scalar_type_dispatch(btype, [&](auto BType) {
-    return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
-        A.scalar_type(), "machete_gemm", [&] {
-          using ComputeType = equivalent_cutlass_type_t<scalar_t>;
-          return GemmDispatcher<ComputeType, decltype(BType)>::dispatch(args);
-        });
-  });
-#else
-  TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
-#endif
-}
-
-torch::Tensor prepack_B(torch::Tensor const& B, ScalarTypeId const btype_id) {
-  ScalarType const btype = ScalarType::from_id(btype_id);
-  return scalar_type_dispatch(btype, [&](auto BType) {
-    return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
-  });
+torch::Tensor prepack_B(
+    torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id,
+    c10::optional<at::ScalarType> const& maybe_group_scales_type) {
+  ScalarType const b_type = ScalarType::from_id(b_type_id);
+  return prepack_B_dispatch(
+      {.B = B,
+       .a_type = a_type,
+       .b_type = b_type,
+       .maybe_group_scales_type = maybe_group_scales_type});
 }
 
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
   m.impl("machete_prepack_B", &prepack_B);
-  m.impl("machete_gemm", &gemm);
+  m.impl("machete_mm", &mm);
 }
 
 // use CatchAll since supported_schedules has no tensor arguments
@@ -203,13 +203,36 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // conditionally compiled so impl in source file
 
   // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
-  ops.def("machete_supported_schedules(int btype) -> str[]");
   ops.def(
-      "machete_gemm(Tensor A, Tensor B, int btype, "
-      " Tensor? scales, Tensor? zeros, int? group_size, "
-      " Tensor? C, float? alpha, float? beta, str? schedule)"
-      "-> Tensor");
-  ops.def("machete_prepack_B(Tensor B, int btype) -> Tensor");
+      "machete_supported_schedules("
+      " ScalarType a_type,"
+      " int b_type,"
+      " ScalarType? maybe_group_scales_type,"
+      " ScalarType? maybe_group_zeros_type,"
+      " ScalarType? maybe_channel_scales_type,"
+      " ScalarType? maybe_token_scales_type,"
+      " ScalarType? maybe_out_type"
+      ") -> str[]");
+  ops.def(
+      "machete_mm("
+      " Tensor A,"
+      " Tensor B,"
+      " int b_type,"
+      " ScalarType? out_type,"
+      " Tensor? group_scales,"
+      " Tensor? group_zeros,"
+      " int? group_size,"
+      " Tensor? channel_scales,"
+      " Tensor? token_scales,"
+      " str? schedule"
+      ") -> Tensor");
+  ops.def(
+      "machete_prepack_B("
+      " Tensor B,"
+      " ScalarType a_type,"
+      " int b_type,"
+      " ScalarType? group_scales_type"
+      ") -> Tensor");
   // conditionally compiled so impl registration is in source file
 
   ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
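A hedged sketch of how keyword arguments map onto the positional `machete_mm` schema registered above; this wrapper is illustrative (vLLM's real entry points are the `ops.machete_*` helpers in `_custom_ops`), and `b_type_id` is the integer id of a vllm `ScalarType` (e.g. `scalar_types.uint4b8.id`):

```python
from typing import Optional
import torch


def machete_mm(a: torch.Tensor, b_prepacked: torch.Tensor, b_type_id: int,
               out_type: Optional[torch.dtype] = None,
               group_scales: Optional[torch.Tensor] = None,
               group_zeros: Optional[torch.Tensor] = None,
               group_size: Optional[int] = None,
               channel_scales: Optional[torch.Tensor] = None,
               token_scales: Optional[torch.Tensor] = None,
               schedule: Optional[str] = None) -> torch.Tensor:
    # Positional order must match the "machete_mm(...)" schema string above;
    # requires the vLLM _C extension to be built and a Hopper GPU at runtime.
    return torch.ops._C.machete_mm(a, b_prepacked, b_type_id, out_type,
                                   group_scales, group_zeros, group_size,
                                   channel_scales, token_scales, schedule)
```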
@@ -1,284 +0,0 @@
-"""Tests for the machete kernel.
-
-Run `pytest tests/kernels/test_machete_gemm.py`.
-"""
-
-import math
-from typing import Optional, Tuple
-
-import pytest
-import torch
-
-from tests.kernels.utils import opcheck
-from vllm import _custom_ops as ops
-from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    pack_rows, quantize_weights)
-from vllm.platforms import current_platform
-from vllm.scalar_type import ScalarType, scalar_types
-
-CUDA_DEVICES = [
-    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
-]
-
-MNK_SHAPES = [
-    (1, 128, 128),
-    (1, 512, 1024),
-    (1, 4096, 4096),
-    (1, 8192, 28672),
-    (13, 8192, 4096),
-    (26, 4096, 8192),
-    (64, 4096, 4096),
-    (64, 8192, 28672),
-    (257, 128, 4096),
-    (257, 4224, 4160),
-    (257, 4096, 4096),
-    (1024, 4096, 8192),
-    (1024, 8192, 4096),
-]
-
-ACT_TYPES = [torch.float16, torch.bfloat16]
-WTYPE_ZEROPOINTS = [
-    # GPTQ style
-    (scalar_types.uint4b8, False),
-    (scalar_types.uint8b128, False),
-    # AWQ style
-    (scalar_types.uint4, True),
-    (scalar_types.uint8, True),
-]
-
-# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
-# unit tests to a common utility function. Currently the use of
-# `is_quant_method_supported` conflates kernels with quantization methods
-# an assumption which is breaking down as quantizations methods can have
-# have kernels and some kernels support multiple quantization methods.
-IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)
-
-
-def rand_data(shape, dtype=torch.float16):
-    return 10 * (torch.rand(shape, dtype=dtype, device="cuda") - 0.3)
-
-
-def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
-    return zps if zps is None else -1 * s * (zps.to(s.dtype))
-
-
-def machete_quantize_and_pack(w: torch.Tensor,
-                              wtype: ScalarType,
-                              group_size: int,
-                              zero_points: bool = False):
-    assert wtype.is_integer(), "TODO: support floating point weights"
-
-    w_ref, w_q, w_s, w_zp = quantize_weights(
-        w,
-        wtype,
-        group_size,
-        zero_points=zero_points,
-        # to match how the kernel applies zps
-        ref_zero_points_after_scales=True)
-
-    w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
-    w_q = w_q.t().contiguous().t()  # convert to col major
-    w_q_machete = ops.machete_prepack_B(w_q, wtype)
-
-    opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype.id))
-
-    return w_ref, w_q_machete, w_s, w_zp
-
-
-def machete_gemm_test_helper(a: torch.Tensor, b: torch.Tensor,
-                             wtype: ScalarType, group_size: int,
-                             zero_points: bool):
-    w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
-        b, wtype, group_size, zero_points)
-
-    output_ref = torch.matmul(a, w_ref)
-
-    output = ops.machete_gemm(
-        a=a,
-        b_q=w_q_packed,
-        b_type=wtype,
-        b_scales=w_s,
-        b_zeros=maybe_convert_zeropoints(w_zp, w_s),
-        b_group_size=group_size,
-    )
-
-    # Relax atol as our reduction dim becomes larger (more rounding error)
-    # Relax atol when we have zeropoints since the way machete applies
-    # zeropoints (after scales) causes noise around 0
-    atol = 1 if zero_points else min(5e-2 * math.sqrt(a.shape[1]), 1)
-    torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol)
-
-
-@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
-                    reason="Machete is not supported on this GPU type.")
-@pytest.mark.parametrize("shape",
-                         MNK_SHAPES,
-                         ids=lambda x: "x".join(str(v) for v in x))
-@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x))
-@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS)
-@pytest.mark.parametrize("group_size", [128, None])
-def test_machete_all_schedules(shape, atype: torch.dtype,
-                               wtype_zeropoints: Tuple[ScalarType, bool],
-                               group_size: Optional[int]):
-    m, n, k = shape
-    wtype, zero_points = wtype_zeropoints
-
-    if group_size is not None and k % group_size != 0:
-        return
-
-    print(f"MNK = {m} {n} {k}")
-
-    # Normalize group_size
-    if group_size is None:
-        group_size = k
-    assert group_size <= k
-
-    a = rand_data((m, k), atype)
-    w = rand_data((k, n), atype)
-
-    w_ref, w_q_machete, w_s, w_zp = machete_quantize_and_pack(
-        w, wtype, group_size, zero_points)
-
-    output_ref = torch.matmul(a, w_ref)
-
-    for schedule in ops.machete_supported_schedules(wtype):
-        print(f"Testing schedule {schedule}")
-        output = ops.machete_gemm(
-            a,
-            b_q=w_q_machete,
-            b_type=wtype,
-            b_scales=w_s,
-            b_zeros=maybe_convert_zeropoints(w_zp, w_s),
-            b_group_size=group_size,
-            schedule=schedule,
-        )
-
-        opcheck(
-            torch.ops._C.machete_gemm,
-            (a, w_q_machete, wtype.id, w_s, maybe_convert_zeropoints(
-                w_zp, w_s), group_size, None, None, None, schedule))
-
-        # Relax atol as our reduction dim becomes larger (more rounding error)
-        # Relax atol when we have zeropoints since the way machete applies
-        # zeropoints (after scales) causes noise around 0
-        atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1)
-        torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol),\
-            f"Schedule failed {schedule}"
-
-
-@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
-                    reason="Machete is not supported on this GPU type.")
-@pytest.mark.parametrize("shape",
-                         MNK_SHAPES,
-                         ids=lambda x: "x".join(str(v) for v in x))
-@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x))
-@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS)
-@pytest.mark.parametrize("group_size", [128, None])
-def test_machete_heuristic(shape, atype: torch.dtype,
-                           wtype_zeropoints: Tuple[ScalarType, bool],
-                           group_size: Optional[int]):
-    m, n, k = shape
-    wtype, zero_points = wtype_zeropoints
-
-    if group_size is not None and k % group_size != 0:
-        return
-
-    # Normalize group_size
-    if group_size is None:
-        group_size = k
-    assert group_size <= k
-
-    a = rand_data((m, k), atype)
-    b = rand_data((k, n), atype)
-
-    machete_gemm_test_helper(a, b, wtype, group_size, zero_points)
-
-
-# Test working on other devices
-@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
-                    reason="Machete is not supported on this GPU type.")
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_machete_devices(device: str):
-    m, n, k = 512, 4096, 4096
-    wtype = scalar_types.uint4b8
-    group_size = 128
-    zero_points = False
-
-    print(f"MNK = {m} {n} {k}, device = {device}")
-
-    a = rand_data((m, k), torch.float16).to(device)
-    b = rand_data((k, n), torch.float16).to(device)
-
-    machete_gemm_test_helper(a, b, wtype, group_size, zero_points)
-
-
-# Test working with a subset of A and B
-@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
-                    reason="Machete is not supported on this GPU type.")
-def test_machete_subset():
-    big_m, big_n, big_k = 1024, 1024, 1024
-    m, n, k = 512, 512, 512
-    wtype = scalar_types.uint4b8
-    group_size = 128
-    zero_points = False
-
-    whole_a = rand_data((big_m, big_k), torch.float16)
-    whole_b = rand_data((big_k, big_n), torch.float16)
-
-    a = whole_a[0:m, 0:k]
-    b = whole_b[0:k, 0:n]
-
-    machete_gemm_test_helper(a, b, wtype, group_size, zero_points)
-
-
-# Test to make sure cuda graphs work
-class MacheteLayer(torch.nn.Module):
-
-    def __init__(self, **kwargs):
-        super().__init__()
-        self.kwargs = kwargs
-
-    def forward(self, a):
-        return ops.machete_gemm(**self.kwargs)
-
-
-@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
-                    reason="Machete is not supported on this GPU type.")
-def test_machete_cuda_graph():
-    m, n, k = 512, 4096, 4096
-
-    a = rand_data((m, k), torch.float16)
-    b = rand_data((k, n), torch.float16)
-    wtype = scalar_types.uint4b8
-    group_size = 128
-    zero_points = False
-
-    w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
-        b, wtype, group_size, zero_points)
-
-    # Construct a trivial model with a single layer that calls a machete kernel
-    model = MacheteLayer(
-        a=a,
-        b_q=w_q_packed,
-        b_type=wtype,
-        b_scales=w_s,
-        b_zeros=maybe_convert_zeropoints(w_zp, w_s),
-        b_group_size=group_size,
-    )
-
-    output_ref = torch.matmul(a, w_ref)
-
-    # Run the model with a cuda graph
-    stream = torch.cuda.Stream()
-    with torch.cuda.stream(stream):
-        g = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(g):
-            output = model(a)
-
-    output.zero_()
-    g.replay()
-
-    # Relax atol as our reduction dim becomes larger (more rounding error)
-    # Relax atol when we have zeropoints since the way machete applies
-    # zeropoints (after scales) causes noise around 0
-    atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1)
-    torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol)
406
tests/kernels/test_machete_mm.py
Normal file
406
tests/kernels/test_machete_mm.py
Normal file
@ -0,0 +1,406 @@
|
|||||||
|
"""Tests for the machete kernel.
|
||||||
|
|
||||||
|
Run `pytest tests/kernels/test_machete_mm.py`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass, fields
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from tests.kernels.utils import opcheck
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
pack_rows, quantize_weights)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.scalar_type import ScalarType, scalar_types
|
||||||
|
|
||||||
|
CUDA_DEVICES = [
|
||||||
|
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||||
|
]
|
||||||
|
|
||||||
|
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
|
||||||
|
# unit tests to a common utility function. Currently the use of
|
||||||
|
# `is_quant_method_supported` conflates kernels with quantization methods
|
||||||
|
# an assumption which is breaking down as quantizations methods can have
|
||||||
|
# have kernels and some kernels support multiple quantization methods.
|
||||||
|
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
|
||||||
|
|
||||||
|
MNK_SHAPES = [
|
||||||
|
(1, 128, 128),
|
||||||
|
(1, 512, 1024),
|
||||||
|
(1, 4096, 4096),
|
||||||
|
(1, 8192, 28672),
|
||||||
|
(13, 8192, 4096),
|
||||||
|
(26, 4096, 8192),
|
||||||
|
(64, 4096, 4096),
|
||||||
|
(64, 8192, 28672),
|
||||||
|
(257, 128, 4096),
|
||||||
|
(257, 4224, 4160),
|
||||||
|
(257, 4096, 4096),
|
||||||
|
(1024, 4096, 8192),
|
||||||
|
(1024, 8192, 4096),
|
||||||
|
]
|
||||||
|
|
||||||
|
GROUP_SIZES_TO_TEST: List[Optional[int]] = [128, -1]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TypeConfig:
|
||||||
|
act_type: torch.dtype
|
||||||
|
weight_type: ScalarType
|
||||||
|
output_type: Optional[torch.dtype]
|
||||||
|
group_scale_type: Optional[torch.dtype]
|
||||||
|
group_zero_type: Optional[torch.dtype]
|
||||||
|
channel_scale_type: Optional[torch.dtype]
|
||||||
|
token_scale_type: Optional[torch.dtype]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Tensors:
|
||||||
|
w_ref: torch.Tensor
|
||||||
|
a_ref: torch.Tensor
|
||||||
|
a: torch.Tensor
|
||||||
|
w_q: torch.Tensor
|
||||||
|
w_g_s: Optional[torch.Tensor]
|
||||||
|
w_g_zp: Optional[torch.Tensor]
|
||||||
|
w_ch_s: Optional[torch.Tensor]
|
||||||
|
w_tok_s: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
# Ch Scales Type, Tok Scales Type)
# NOTE: None "Scale Type" means the act type is floating point
# None "Output Type" means the output type is the same as the act type
TestTypeTuple = Tuple[List[torch.dtype], ScalarType, Optional[torch.dtype],
                      Optional[torch.dtype], bool]
TEST_TYPES = [
    # GPTQ style
    *(TypeConfig(act_type=a_type,
                 weight_type=w_type,
                 output_type=None,
                 group_scale_type=a_type,
                 group_zero_type=None,
                 channel_scale_type=None,
                 token_scale_type=None)
      for w_type in [scalar_types.uint4b8, scalar_types.uint8b128]
      for a_type in [torch.float16, torch.bfloat16]),
    # AWQ style
    *(TypeConfig(act_type=a_type,
                 weight_type=w_type,
                 output_type=None,
                 group_scale_type=a_type,
                 group_zero_type=a_type,
                 channel_scale_type=None,
                 token_scale_type=None)
      for w_type in [scalar_types.uint4, scalar_types.uint8]
      for a_type in [torch.float16, torch.bfloat16]),
    # QQQ style
    *(TypeConfig(act_type=torch.int8,
                 weight_type=scalar_types.uint4b8,
                 output_type=torch.float16,
                 group_scale_type=group_scale_type,
                 group_zero_type=None,
                 channel_scale_type=torch.float,
                 token_scale_type=torch.float)
      for group_scale_type in [None, torch.float16]),
    *(TypeConfig(act_type=torch.float8_e4m3fn,
                 weight_type=scalar_types.uint4b8,
                 output_type=torch.float16,
                 group_scale_type=group_scale_type,
                 group_zero_type=None,
                 channel_scale_type=torch.float,
                 token_scale_type=torch.float)
      for group_scale_type in [None, torch.float16]),
]

# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
# unit tests to a common utility function. Currently the use of
# `is_quant_method_supported` conflates kernels with quantization methods
# an assumption which is breaking down as quantizations methods can have
# have kernels and some kernels support multiple quantization methods.
IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)


def rand_data(shape, dtype=torch.float16, scale=1, offset=0):
    if dtype.is_floating_point:
        return (scale * torch.rand(shape, device="cuda") - offset).to(dtype)
    else:
        return torch.randint(-8, 7, shape, dtype=dtype, device="cuda")


def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
    return zps if zps is None else -1 * s * (zps.to(s.dtype))


def group_size_valid(shape: Tuple[int, int, int],
                     group_size: Optional[int]) -> bool:
    return group_size is None or group_size == -1 or group_size % shape[2] == 0


def machete_quantize_and_pack(atype: torch.dtype,
                              w: torch.Tensor,
                              wtype: ScalarType,
                              stype: Optional[torch.dtype],
                              group_size: Optional[int],
                              zero_points: bool = False):
    assert wtype.is_integer(), "TODO: support floating point weights"

    w_ref, w_q, w_s, w_zp = quantize_weights(
        w,
        wtype,
        group_size=group_size,
        zero_points=zero_points,
        # to match how the kernel applies zps
        ref_zero_points_after_scales=True)

    w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
    w_q = w_q.t().contiguous().t()  # convert to col major

    w_q_machete = ops.machete_prepack_B(w_q, atype, wtype, stype)
    opcheck(torch.ops._C.machete_prepack_B, (w_q, atype, wtype.id, stype))

    return w_ref, w_q_machete, w_s, w_zp


def create_test_tensors(shape: Tuple[int, int, int],
                        types: TypeConfig,
                        group_size: Optional[int],
                        subset_stride_factor: Optional[int] = None) -> Tensors:
    m, n, k = shape
    factor = subset_stride_factor or 1

    print("create_test_tensors, shape:", shape, "types:", types, "group_size:",
          group_size)

    a = rand_data((m * factor, k * factor), types.act_type, scale=3, offset=2)
    w = rand_data((k * factor, n * factor), types.act_type, scale=3, offset=1)

    if factor > 1:
        a = a[0:m, 0:k]
        w = w[0:k, 0:n]

    if types.group_scale_type is not None:
        w = w.to(types.group_scale_type)
    if w.dtype.itemsize == 1:
        w = w.to(torch.float16)

    w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
        a.dtype, w, types.weight_type, types.group_scale_type, group_size,
        types.group_zero_type is not None)

    if not a.dtype.is_floating_point:
        aiinfo = torch.iinfo(a.dtype)
        w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max)

    a_ref = a.to(torch.float32)
    w_ref = w_ref.to(torch.float32)

    w_ch_s = None if types.channel_scale_type is None else\
        rand_data((n,), types.channel_scale_type)
    w_tok_s = None if types.token_scale_type is None else\
        rand_data((m,), types.token_scale_type)

    return Tensors(w_ref=w_ref,
                   a_ref=a_ref,
                   a=a,
                   w_q=w_q_packed,
                   w_g_s=w_s,
                   w_g_zp=maybe_convert_zeropoints(w_zp, w_s),
                   w_ch_s=w_ch_s,
                   w_tok_s=w_tok_s)


# None stype means scales use the same dtype as a
def machete_mm_test_helper(types: TypeConfig,
                           tensors: Tensors,
                           group_size: Optional[int] = None,
                           schedule: Optional[str] = None):
    output_ref = torch.matmul(tensors.a_ref, tensors.w_ref)
    output_ref_type = output_ref.dtype

    if tensors.w_ch_s is not None:
        output_ref = (output_ref.to(tensors.w_ch_s.dtype) *
                      tensors.w_ch_s.unsqueeze(0)).to(output_ref_type)
    if tensors.w_tok_s is not None:
        output_ref = (output_ref.to(tensors.w_tok_s.dtype) *
                      tensors.w_tok_s.unsqueeze(1)).to(output_ref_type)

    output = ops.machete_mm(
        a=tensors.a,
        b_q=tensors.w_q,
        b_type=types.weight_type,
        b_group_scales=tensors.w_g_s,
        b_group_zeros=tensors.w_g_zp,
        b_group_size=group_size,
        b_channel_scales=tensors.w_ch_s,
        a_token_scales=tensors.w_tok_s,
        out_type=types.output_type,
        schedule=schedule,
    )

    print(output)
    print(output_ref)

    # Relax atol as our reduction dim becomes larger (more rounding error)
    # Relax atol when we have zeropoints since the way machete applies
    # zeropoints (after scales) causes noise around 0
    atol = 1 if tensors.w_g_zp is not None\
        else min(5e-2 * math.sqrt(tensors.a.shape[1]), 1)
    rtol = 1e-1 if tensors.a.element_size() >= 2 else 2e-1
    torch.testing.assert_close(output,
                               output_ref.to(output.dtype),
                               rtol=rtol,
                               atol=atol)


@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                    reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
                         MNK_SHAPES,
                         ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_all_schedules(shape, types: TypeConfig):

    group_sizes: List[Optional[int]] = []
    if types.group_scale_type is None:
        group_sizes = [None]
    else:
        group_sizes = GROUP_SIZES_TO_TEST

    for group_size in group_sizes:
        if not group_size_valid(shape, group_size):
            continue

        tensors = create_test_tensors(shape, types, group_size)
        print(f"MNK = {shape}")
        for schedule in ops.machete_supported_schedules(
                types.act_type,
                types.weight_type,
                group_scales_type=types.group_scale_type,
                group_zeros_type=types.group_scale_type,
                out_type=types.output_type):
            print(f"Testing schedule {schedule}")
            machete_mm_test_helper(types, tensors, group_size, schedule)


@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                    reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
                         MNK_SHAPES,
                         ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_heuristic(shape, types: TypeConfig):
    group_sizes: List[Optional[int]] = []
    if types.group_scale_type is None:
        group_sizes = [None]
    else:
        group_sizes = GROUP_SIZES_TO_TEST

    for group_size in group_sizes:
        if not group_size_valid(shape, group_size):
            continue

        tensors = create_test_tensors(shape, types, group_size)
        machete_mm_test_helper(types, tensors, group_size)


# Test working on other devices
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                    reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_machete_devices(device: str):
    group_size = 128

    type_config = TypeConfig(act_type=torch.float16,
                             weight_type=scalar_types.uint4b8,
                             output_type=None,
                             group_scale_type=torch.float16,
                             group_zero_type=None,
                             channel_scale_type=None,
                             token_scale_type=None)

    tensors = create_test_tensors((512, 4096, 4096), type_config, group_size)

    for field in fields(Tensors):
        tensor = getattr(tensors, field.name)
        if isinstance(tensor, torch.Tensor):
            setattr(tensors, field.name, tensor.to(device))

    machete_mm_test_helper(type_config, tensors, group_size)


# Test working with a subset of A and B
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                    reason="Machete is not supported on this GPU type.")
def test_machete_subset():
    group_size = 128

    type_config = TypeConfig(act_type=torch.float16,
                             weight_type=scalar_types.uint4b8,
                             output_type=None,
                             group_scale_type=torch.float16,
                             group_zero_type=None,
                             channel_scale_type=None,
                             token_scale_type=None)

    tensors = create_test_tensors((512, 4096, 4096),
                                  type_config,
                                  group_size,
                                  subset_stride_factor=2)
    machete_mm_test_helper(type_config, tensors, group_size)


# Test to make sure cuda graphs work
class MacheteLayer(torch.nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        self.kwargs = kwargs

    def forward(self, a):
        return ops.machete_mm(a=a, **self.kwargs)


@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                    reason="Machete is not supported on this GPU type.")
def test_machete_cuda_graph():
    m, n, k = 512, 4096, 4096

    a = rand_data((m, k), torch.float16)
    b = rand_data((k, n), torch.float16)
    wtype = scalar_types.uint4b8
    stype = torch.float16
    group_size = 128
    zero_points = False

    w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
        a.dtype, b, wtype, stype, group_size, zero_points)

    # Construct a trivial model with a single layer that calls a machete kernel
    model = MacheteLayer(
        b_q=w_q_packed,
        b_type=wtype,
        b_group_scales=w_s,
        b_group_zeros=maybe_convert_zeropoints(w_zp, w_s),
        b_group_size=group_size,
    )

    output_ref = torch.matmul(a, w_ref)

    # Run the model with a cuda graph
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            output = model(a)
    output.zero_()
    g.replay()

    # Relax atol as our reduction dim becomes larger (more rounding error)
    # Relax atol when we have zeropoints since the way machete applies
    # zeropoints (after scales) causes noise around 0
    atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1)
    torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol)

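One detail of the new tests worth spelling out is the tolerance rule shared by `machete_mm_test_helper` and `test_machete_cuda_graph`: the absolute tolerance grows with the reduction dimension k (more accumulation rounding) but is capped at 1, it is pinned to 1 whenever group zero points are present, and 1-byte activations (int8 / fp8) get a looser relative tolerance. Below is a minimal standalone sketch of that rule; the helper name `expected_tolerances` is made up for illustration.

import math


def expected_tolerances(k: int, has_group_zeros: bool,
                        act_itemsize: int) -> tuple:
    # Mirrors machete_mm_test_helper above: atol scales with sqrt(k), capped
    # at 1, and is forced to 1 when group zero points add noise around 0;
    # rtol is loosened for 1-byte (int8 / fp8) activations.
    atol = 1 if has_group_zeros else min(5e-2 * math.sqrt(k), 1)
    rtol = 1e-1 if act_itemsize >= 2 else 2e-1
    return atol, rtol


print(expected_tolerances(k=128, has_group_zeros=False, act_itemsize=2))
# -> (~0.57, 0.1)
print(expected_tolerances(k=28672, has_group_zeros=False, act_itemsize=1))
# -> (1, 0.2)
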
@ -444,18 +444,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
                               size_k: torch.SymInt) -> torch.Tensor:
         return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
 
-    @register_fake("_C::machete_gemm")
-    def machete_gemm_fake(
+    @register_fake("_C::machete_mm")
+    def machete_mm_fake(
         a: torch.Tensor,
-        # Should be the tensor returned by machete_prepack_B
+        # b_q Should be the tensor returned by machete_prepack_B
         b_q: torch.Tensor,
         b_type: ScalarType,
-        b_scales: Optional[torch.Tensor] = None,
-        b_zeros: Optional[torch.Tensor] = None,
+        out_type: Optional[torch.dtype] = None,
+        b_group_scales: Optional[torch.Tensor] = None,
+        b_group_zeros: Optional[torch.Tensor] = None,
         b_group_size: Optional[int] = None,
-        c: Optional[torch.Tensor] = None,
-        alpha: Optional[float] = None,
-        beta: Optional[float] = None,
+        b_channel_scales: Optional[torch.Tensor] = None,
+        a_token_scales: Optional[torch.Tensor] = None,
         schedule: Optional[str] = None,
     ) -> torch.Tensor:
         m = a.size(0)
@ -463,8 +463,9 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
         return torch.empty((m, n), device=a.device, dtype=a.dtype)
 
     @register_fake("_C::machete_prepack_B")
-    def machete_prepack_B_fake(b_q_weight: torch.Tensor,
-                               b_type: ScalarType) -> torch.Tensor:
+    def machete_prepack_B_fake(
+            b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType,
+            group_scales_type: Optional[torch.dtype]) -> torch.Tensor:
         return torch.empty_like(b_q_weight,
                                 memory_format=torch.contiguous_format)
 
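The fake registrations above exist so that torch.compile can trace these custom ops without executing them: a fake (meta) implementation only has to produce an output with the right shape, dtype and device. A generic sketch of the mechanism, using a made-up `mylib::row_sum` op rather than anything from this diff:

import torch
from torch.library import custom_op, register_fake


@custom_op("mylib::row_sum", mutates_args=())
def row_sum(x: torch.Tensor) -> torch.Tensor:
    # Real (eager) implementation.
    return x.sum(dim=-1)


@register_fake("mylib::row_sum")
def _row_sum_fake(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation used during tracing; no data is computed.
    return x.new_empty(x.shape[:-1])
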
@ -617,29 +618,41 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 
 # machete
-def machete_supported_schedules(b_type: ScalarType) -> List[str]:
-    return torch.ops._C.machete_supported_schedules(b_type.id)
+def machete_supported_schedules(
+        a_type: torch.dtype,
+        b_type: ScalarType,
+        group_scales_type: Optional[torch.dtype],
+        group_zeros_type: Optional[torch.dtype] = None,
+        channel_scales_type: Optional[torch.dtype] = None,
+        token_scales_type: Optional[torch.dtype] = None,
+        out_type: Optional[torch.dtype] = None) -> List[str]:
+    return torch.ops._C.machete_supported_schedules(
+        a_type, b_type.id, group_scales_type, group_zeros_type,
+        channel_scales_type, token_scales_type, out_type)
 
 
-def machete_gemm(
-    a: torch.Tensor,
-    b_q: torch.Tensor,  # Should be the tensor returned by machete_prepack_B
-    b_type: ScalarType,
-    b_scales: Optional[torch.Tensor] = None,
-    b_zeros: Optional[torch.Tensor] = None,
-    b_group_size: Optional[int] = None,
-    c: Optional[torch.Tensor] = None,
-    alpha: Optional[float] = None,
-    beta: Optional[float] = None,
-    schedule: Optional[str] = None,
-) -> torch.Tensor:
-    return torch.ops._C.machete_gemm(a, b_q, b_type.id, b_scales, b_zeros,
-                                     b_group_size, c, alpha, beta, schedule)
+def machete_mm(
+        a: torch.Tensor,
+        # b_q Should be the tensor returned by machete_prepack_B
+        b_q: torch.Tensor,
+        b_type: ScalarType,
+        out_type: Optional[torch.dtype] = None,
+        b_group_scales: Optional[torch.Tensor] = None,
+        b_group_zeros: Optional[torch.Tensor] = None,
+        b_group_size: Optional[int] = None,
+        b_channel_scales: Optional[torch.Tensor] = None,
+        a_token_scales: Optional[torch.Tensor] = None,
+        schedule: Optional[str] = None) -> torch.Tensor:
+    return torch.ops._C.machete_mm(a, b_q, b_type.id, out_type, b_group_scales,
+                                   b_group_zeros, b_group_size,
+                                   b_channel_scales, a_token_scales, schedule)
 
 
-def machete_prepack_B(b_q_weight: torch.Tensor,
-                      b_type: ScalarType) -> torch.Tensor:
-    return torch.ops._C.machete_prepack_B(b_q_weight, b_type.id)
+def machete_prepack_B(
+        b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType,
+        group_scales_type: Optional[torch.dtype]) -> torch.Tensor:
+    return torch.ops._C.machete_prepack_B(b_q_weight, a_type, b_type.id,
+                                          group_scales_type)
 
 
 if hasattr(torch.ops._C, "permute_cols"):

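Taken together, the three wrappers above form the whole Machete call path: quantize and row-pack the weight, let `machete_prepack_B` reshuffle it into the kernel's layout, optionally query `machete_supported_schedules`, then call `machete_mm`. The sketch below strings them together the same way `machete_quantize_and_pack` and `machete_mm_test_helper` in the new test file do; the dtype and scalar-type choices mirror the GPTQ-style test config, variable names are illustrative, and a Hopper (SM90) GPU is assumed.

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_rows, quantize_weights)
from vllm.scalar_type import scalar_types

m, n, k, group_size = 16, 4096, 4096, 128
wtype = scalar_types.uint4b8  # GPTQ-style 4-bit weights

a = torch.rand((m, k), dtype=torch.float16, device="cuda")
w = torch.rand((k, n), dtype=torch.float16, device="cuda")

# Quantize, pack 4-bit values along rows, then pre-shuffle B for Machete.
w_ref, w_q, w_s, _ = quantize_weights(w, wtype, group_size=group_size)
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
w_q = w_q.t().contiguous().t()  # column major, as the tests do
w_q_machete = ops.machete_prepack_B(w_q, a.dtype, wtype, torch.float16)

# Either pick an explicit schedule or pass schedule=None to use the heuristic.
schedules = ops.machete_supported_schedules(
    a.dtype, wtype, group_scales_type=torch.float16)

out = ops.machete_mm(a=a,
                     b_q=w_q_machete,
                     b_type=wtype,
                     b_group_scales=w_s,
                     b_group_size=group_size,
                     schedule=schedules[0])
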
@ -79,7 +79,9 @@ class MacheteLinearKernel(MPLinearKernel):
                                          c.weight_type,
                                          packed_dim=0)
             x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
-                                           self.config.weight_type)
+                                           a_type=c.act_type,
+                                           b_type=c.weight_type,
+                                           group_scales_type=c.act_type)
             return x
 
         def transform_w_s(x):
@ -105,12 +107,12 @@ class MacheteLinearKernel(MPLinearKernel):
         if c.has_g_idx:
             x_2d = self.act_perm(x_2d)
 
-        output = ops.machete_gemm(a=x_2d,
+        output = ops.machete_mm(a=x_2d,
                                   b_q=w_q,
                                   b_type=c.weight_type,
-                                  b_zeros=None,
-                                  b_scales=w_s,
+                                  b_group_zeros=None,
+                                  b_group_scales=w_s,
                                   b_group_size=c.group_size)
 
         if bias is not None:
             output.add_(bias)  # In-place add

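For callers outside this file the update is essentially a renaming exercise, as the hunk above shows: `machete_gemm` becomes `machete_mm`, `b_scales`/`b_zeros` become `b_group_scales`/`b_group_zeros`, the old `c`/`alpha`/`beta` epilogue arguments are gone, and `out_type`, `b_channel_scales` and `a_token_scales` arrive as new optional arguments. A hedged wrapper sketch of the new call; the function name `apply_machete` is made up for illustration.

from typing import Optional

import torch

from vllm import _custom_ops as ops
from vllm.scalar_type import ScalarType


def apply_machete(x_2d: torch.Tensor, w_q: torch.Tensor, w_s: torch.Tensor,
                  weight_type: ScalarType,
                  group_size: Optional[int]) -> torch.Tensor:
    # Equivalent of the old ops.machete_gemm(..., b_scales=w_s, b_zeros=None)
    # call, expressed against the renamed keyword arguments.
    return ops.machete_mm(a=x_2d,
                          b_q=w_q,
                          b_type=weight_type,
                          b_group_scales=w_s,
                          b_group_zeros=None,
                          b_group_size=group_size)
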
@ -126,11 +126,14 @@ def permute_rows(q_w: torch.Tensor,
 
 def quantize_weights(w: torch.Tensor,
                      quant_type: ScalarType,
-                     group_size: int,
+                     group_size: Optional[int],
                      zero_points: bool = False,
                      ref_zero_points_after_scales: bool = False):
     assert quant_type.is_integer(), \
         "Floating point quantization may work but has not been tested"
+    assert not zero_points or group_size is not None, \
+        "to have group zero points, group_size must be provided "\
+        "(-1 group_size is channelwise)"
 
     orig_device = w.device
     orig_type = w.dtype
@ -140,10 +143,9 @@ def quantize_weights(w: torch.Tensor,
 
     if group_size == -1:
         group_size = size_k
-    assert group_size <= size_k
 
     # Reshape to [groupsize, -1]
-    if group_size < size_k:
+    if group_size is not None and group_size < size_k:
         w = w.reshape((-1, group_size, size_n))
         w = w.permute(1, 0, 2)
         w = w.reshape((group_size, -1))
@ -155,18 +157,20 @@ def quantize_weights(w: torch.Tensor,
     max_q_val = quant_type.max()
     min_q_val = quant_type.min()
 
-    if zero_points:
-        assert not quant_type.is_signed() and quant_type.max() > 0
-        w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
-        maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
-            .clamp(min_q_val, max_q_val).int()
-    else:
-        # If the bias is such that there are no possible negative/positive
-        # values, set the max value to inf to avoid divide by 0
-        w_s = torch.max(
-            abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
-            abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
-        maybe_w_zp = None
+    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
+    maybe_w_zp = None
+    if group_size is not None:
+        if zero_points:
+            assert not quant_type.is_signed() and quant_type.max() > 0
+            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
+            maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
+                .clamp(min_q_val, max_q_val).int()
+        else:
+            # If the bias is such that there are no possible negative/positive
+            # values, set the max value to inf to avoid divide by 0
+            w_s = torch.max(
+                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
+                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
 
     # Quantize
     w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
@ -176,7 +180,7 @@ def quantize_weights(w: torch.Tensor,
     # For some kernels (namely Machete) the zero-points are applied after the
     # scales are applied, for this case computing the reference in similar way
     # allows us to use tighter error tolerances in our unit tests.
-    if ref_zero_points_after_scales and zero_points:
+    if ref_zero_points_after_scales and maybe_w_zp is not None:
         w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
     else:
         w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
@ -185,7 +189,7 @@ def quantize_weights(w: torch.Tensor,
     w_q += quant_type.bias
 
     # Restore original shapes
-    if group_size < size_k:
+    if group_size is not None and group_size < size_k:
 
         def reshape_w(w):
             w = w.reshape((group_size, -1, size_n))
@ -195,17 +199,16 @@ def quantize_weights(w: torch.Tensor,
 
         w_q = reshape_w(w_q)
         w_ref = reshape_w(w_ref)
+        w_s = w_s.reshape((-1, size_n)).contiguous()
 
-    w_s = w_s.reshape((-1, size_n)).contiguous()
-
-    if zero_points:
+    if maybe_w_zp is not None:
         maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
         maybe_w_zp = maybe_w_zp.to(device=orig_device)
 
     return (
         w_ref.to(device=orig_device),
         w_q.to(device=orig_device),
-        w_s.to(device=orig_device),
+        w_s if group_size is not None else None,
         maybe_w_zp,
     )

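The net effect of the `quantize_weights` changes is that `group_size` is now genuinely optional: passing `None` skips scaling entirely and the scale slot in the returned tuple comes back as `None`, `-1` still means one group spanning the whole reduction dimension (channelwise), and asking for zero points without a group size now trips the new assert. A short sketch of the two modes, with illustrative shapes:

import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    quantize_weights)
from vllm.scalar_type import scalar_types

w = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

# Grouped quantization, as before: scales come back with one row per group.
w_ref, w_q, w_s, w_zp = quantize_weights(w, scalar_types.uint4b8,
                                         group_size=128)
assert w_s.shape == (4096 // 128, 4096)

# New: group_size=None quantizes without scales; the scale slot is None.
w_ref2, w_q2, w_s2, w_zp2 = quantize_weights(w, scalar_types.uint4b8,
                                             group_size=None)
assert w_s2 is None and w_zp2 is None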