[Kernel] Add per-tensor and per-token AZP epilogues (#5941)
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Parent: 5c60c8c423
Commit: 8d59dbb000
@ -32,7 +32,6 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
||||
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
a = torch.randn((m, k), device='cuda') * 5
|
||||
b = torch.randn((n, k), device='cuda').t() * 5
|
||||
|
||||
@ -44,59 +43,18 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
|
||||
# impl
|
||||
|
||||
|
||||
def pytorch_mm_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype) -> torch.Tensor:
|
||||
return torch.mm(a, b)
|
||||
|
||||
|
||||
def pytorch_fp8_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype) -> torch.Tensor:
|
||||
return torch._scaled_mm(a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=out_dtype)
|
||||
|
||||
|
||||
def pytorch_fp8_impl_fast_accum(a: torch.Tensor, b: torch.Tensor,
|
||||
scale_a: torch.Tensor, scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype) -> torch.Tensor:
|
||||
return torch._scaled_mm(a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=out_dtype,
|
||||
use_fast_accum=True)
|
||||
|
||||
|
||||
def cutlass_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype) -> torch.Tensor:
|
||||
return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)
|
||||
|
||||
|
||||
# bench
|
||||
def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor, out_dtype: torch.dtype, label: str,
|
||||
sub_label: str, fn: Callable, description: str) -> TMeasurement:
|
||||
|
||||
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
|
||||
**kwargs) -> TMeasurement:
|
||||
min_run_time = 1
|
||||
|
||||
globals = {
|
||||
"a": a,
|
||||
"b": b,
|
||||
"scale_a": scale_a,
|
||||
"scale_b": scale_b,
|
||||
"out_dtype": out_dtype,
|
||||
"args": args,
|
||||
"kwargs": kwargs,
|
||||
"fn": fn,
|
||||
}
|
||||
return TBenchmark.Timer(
|
||||
stmt="fn(a, b, scale_a, scale_b, out_dtype)",
|
||||
stmt="fn(*args, **kwargs)",
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
@ -110,26 +68,58 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
a, b = make_rand_tensors(torch.int8, m, n, k)
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||
azp = torch.zeros((m, ), device="cuda", dtype=torch.int32)
|
||||
azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32)
|
||||
|
||||
timers = []
|
||||
# pytorch impl - bfloat16
|
||||
timers.append(
|
||||
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
|
||||
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
|
||||
torch.bfloat16, label, sub_label, pytorch_mm_impl,
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
|
||||
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||
torch.mm, a.to(dtype=torch.bfloat16),
|
||||
b.to(dtype=torch.bfloat16)))
|
||||
|
||||
# pytorch impl - float16
|
||||
timers.append(
|
||||
bench_fn(a.to(dtype=torch.float16, device="cuda"),
|
||||
b.to(dtype=torch.float16, device="cuda"), scale_a, scale_b,
|
||||
torch.float16, label, sub_label, pytorch_mm_impl,
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales"))
|
||||
bench_fn(label, sub_label,
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
|
||||
a.to(dtype=torch.float16), b.to(dtype=torch.float16)))
|
||||
|
||||
# cutlass impl
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
|
||||
cutlass_impl, "cutlass_i8_i8_bf16_scaled_mm"))
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
|
||||
torch.bfloat16))
|
||||
|
||||
# cutlass with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
|
||||
bias))
|
||||
|
||||
# cutlass with azp per-tensor
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp",
|
||||
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||
torch.bfloat16, azp_adj))
|
||||
|
||||
# cutlass with azp per-tensor + bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_bias",
|
||||
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||
torch.bfloat16, azp_adj, None, bias))
|
||||
|
||||
# cutlass with azp per-token
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt",
|
||||
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||
torch.bfloat16, azp_adj, azp))
|
||||
|
||||
# cutlass with azp per-token + bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias",
|
||||
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||
torch.bfloat16, azp_adj, azp, bias))
|
||||
|
||||
return timers
|
||||
|
||||
@ -140,46 +130,88 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
timers = []
|
||||
|
||||
# pytorch impl w. bf16
|
||||
timers.append(
|
||||
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
|
||||
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
|
||||
torch.bfloat16, label, sub_label, pytorch_mm_impl,
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
|
||||
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||
torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
|
||||
b.to(dtype=torch.bfloat16, device="cuda")))
|
||||
|
||||
# pytorch impl: bf16 output, without fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
|
||||
pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm"))
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.bfloat16))
|
||||
|
||||
# pytorch impl: bf16 output, with fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
|
||||
pytorch_fp8_impl_fast_accum,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"))
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=True))
|
||||
|
||||
# pytorch impl: fp16 output, without fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
|
||||
pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm"))
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.float16))
|
||||
|
||||
# pytorch impl: fp16 output, with fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
|
||||
pytorch_fp8_impl_fast_accum,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"))
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.float16,
|
||||
use_fast_accum=True))
|
||||
|
||||
# cutlass impl: bf16 output
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
|
||||
cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm"))
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
|
||||
torch.bfloat16))
|
||||
# cutlass impl: fp16 output
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
|
||||
cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm"))
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16))
|
||||
|
||||
# cutlass impl: bf16 output, with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm_bias",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
|
||||
bias))
|
||||
|
||||
# cutlass impl: fp16 output, with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm_bias",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16,
|
||||
bias.to(dtype=torch.float16)))
|
||||
|
||||
return timers
|
||||
|
||||
|
||||
@ -200,7 +232,6 @@ def print_timers(timers: Iterable[TMeasurement]):
|
||||
|
||||
def run(dtype: torch.dtype,
|
||||
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||
|
||||
results = []
|
||||
for m, k, n in MKNs:
|
||||
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
|
||||
@ -216,7 +247,6 @@ def make_output(data: Iterable[TMeasurement],
|
||||
MKNs: Iterable[Tuple[int, int, int]],
|
||||
base_description: str,
|
||||
timestamp=None):
|
||||
|
||||
print(f"== All Results {base_description} ====")
|
||||
print_timers(data)
|
||||
|
||||
@ -251,7 +281,6 @@ def run_range_bench(args):
|
||||
|
||||
|
||||
def run_model_bench(args):
|
||||
|
||||
print("Benchmarking models:")
|
||||
for i, model in enumerate(args.models):
|
||||
print(f"[{i}] {model}")
|
||||
|
@@ -128,6 +128,14 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b_scales,
                       c10::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
                           c10::optional<torch::Tensor> const& azp,
                           c10::optional<torch::Tensor> const& bias);

torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
                              torch::Tensor const& b_q_weight,
                              torch::Tensor const& s_tok,
csrc/quantization/cutlass_w8a8/Epilogues.md (new file, 147 lines)
@@ -0,0 +1,147 @@
# CUTLASS Epilogues

## Introduction
This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs.

Currently, we only support symmetric quantization for weights,
and symmetric and asymmetric quantization for activations.
Both can be quantized per-tensor or per-channel (weights) / per-token (activations).

There are 4 epilogues:
1. ScaledEpilogue: symmetric quantization for activations, no bias.
1. ScaledEpilogueBias: symmetric quantization for activations, supports bias.
1. ScaledEpilogueAzp: asymmetric per-tensor quantization for activations, supports bias.
1. ScaledEpilogueAzpPerToken: asymmetric per-token quantization for activations, supports bias.

We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size.
Instead, if no bias is passed, the epilogue will use 0 as the bias.
That induces a redundant addition operation (and runtime check), but the performance impact is minor.
## Underlying Linear Algebra

More details are available in the [Activation Quantization RFC](https://github.com/vllm-project/vllm/issues/3975).

If $` \widehat X `$ is the quantized $` X `$, our matrices become the following:

```math
A = s_a (\widehat A - J_a z_a)
```
```math
B = s_b \widehat B
```
```math
D = A B + C
```
```math
D = s_a s_b \widehat D + C
```

Here, D is the output of the GEMM, and C is the bias.
A holds the activations and supports asymmetric quantization,
while B holds the weights and supports only symmetric quantization.
$` s_a `$ and $` s_b `$ are the scales for activations and weights, respectively.
$` z_a `$ is the zero-point for activations, and $` J_a `$ is the matrix of all ones with the dimensions of A.
Additional epilogues would be required to support asymmetric quantization for weights.

Expanding further, we can calculate $` \widehat D `$ as follows:

```math
A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B
```
```math
A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right)
```
```math
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
```

Note that $` \widehat A \widehat B `$ is the raw output of the GEMM,
and $` J_a \widehat B `$ is known ahead of time.
Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of column sums of $` \widehat B `$.
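As a quick numerical check of that identity, here is a minimal PyTorch sketch (tensor names are illustrative and not part of the kernels; int8-range values are stored as int32 so the CPU matmul is exact):

```python
import torch

m, n, k = 4, 8, 16
a_q = torch.randint(-128, 128, (m, k), dtype=torch.int32)  # asymmetrically quantized activations
b_q = torch.randint(-128, 128, (k, n), dtype=torch.int32)  # symmetrically quantized weights
z_a = 7                                                     # per-tensor activation zero-point

# Raw integer GEMM output and the correction term that is known ahead of time.
raw = a_q @ b_q                      # \widehat A \widehat B
col_sums = b_q.sum(dim=0)            # \mathbf 1 \widehat B: column sums of \widehat B, shape (n,)
d_hat = raw - z_a * col_sums         # \widehat D, the same row-vector subtracted from every row

# Equivalent to removing the zero-point from A before the GEMM.
reference = (a_q - z_a) @ b_q
assert (d_hat == reference).all()
```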
## Epilogues

### ScaledEpilogue
This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$.
The output of the GEMM is:

```math
\widehat D = \widehat A \widehat B
```
```math
D = s_a s_b \widehat D
```
```math
D = s_a s_b \widehat A \widehat B
```

Epilogue parameters:
- `scale_a` is the scale for activations; it can be per-tensor (scalar) or per-token (column-vector).
- `scale_b` is the scale for weights; it can be per-tensor (scalar) or per-channel (row-vector).
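For intuition about the shapes, here is a minimal PyTorch reference of this epilogue's math (a sketch only, not the CUTLASS implementation; names are illustrative). With a per-token `scale_a` of shape (m, 1) and a per-channel `scale_b` of shape (1, n), dequantization is a plain broadcast:

```python
import torch

m, n, k = 4, 8, 16
a_q = torch.randint(-128, 128, (m, k), dtype=torch.int32)
b_q = torch.randint(-128, 128, (k, n), dtype=torch.int32)
scale_a = torch.rand(m, 1)     # per-token (column-vector); a scalar works too
scale_b = torch.rand(1, n)     # per-channel (row-vector); a scalar works too

d_hat = (a_q @ b_q).float()    # raw int32 accumulator
d = scale_a * scale_b * d_hat  # D = s_a s_b \widehat D, broadcast to (m, n)
```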
### ScaledEpilogueBias
This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$.
The output of the GEMM is:

```math
\widehat D = \widehat A \widehat B
```
```math
D = s_a s_b \widehat D + C
```
```math
D = s_a s_b \widehat A \widehat B + C
```

Epilogue parameters:
- `scale_a` is the scale for activations; it can be per-tensor (scalar) or per-token (column-vector).
- `scale_b` is the scale for weights; it can be per-tensor (scalar) or per-channel (row-vector).
- `bias` is the bias; it is always per-channel (row-vector).
### ScaledEpilogueAzp
This epilogue computes the asymmetric per-tensor quantization for activations with bias.
The output of the GEMM is:

```math
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
```
```math
D = s_a s_b \widehat D + C
```
```math
D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C
```

Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 \widehat B `$.
That term is precomputed and stored in `azp_with_adj` as a row-vector.

Epilogue parameters:
- `scale_a` is the scale for activations; it can be per-tensor (scalar) or per-token (column-vector).
  - Generally this will be per-tensor, as the zero-points are per-tensor.
- `scale_b` is the scale for weights; it can be per-tensor (scalar) or per-channel (row-vector).
- `azp_with_adj` is the precomputed zero-point term ($` z_a J_a \widehat B `$); it is per-channel (row-vector).
- `bias` is the bias; it is always per-channel (row-vector).

To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel.
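A hedged sketch of that offline precomputation in PyTorch (variable names are illustrative; `b_q` is the quantized weight matrix and `azp` the scalar activation zero-point):

```python
import torch

k, n = 16, 8
b_q = torch.randint(-128, 128, (k, n), dtype=torch.int8)  # quantized weights \widehat B
azp = 7                                                    # scalar activation zero-point z_a

# Row-vector z_a * (1^T \widehat B): the full per-tensor correction folded into one int32 tensor.
azp_with_adj = azp * b_q.sum(dim=0, dtype=torch.int32)     # shape (n,)

# Reference epilogue math for the per-tensor case (Dq is the raw int32 GEMM output):
# out = scale_a * scale_b * (Dq - azp_with_adj) + bias
```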
### ScaledEpilogueAzpPerToken
This epilogue computes the asymmetric per-token quantization for activations with bias.

The output of the GEMM is the same as above, but $` z_a `$ is a column-vector.
That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$.

Epilogue parameters:
- `scale_a` is the scale for activations; it can be per-tensor (scalar) or per-token (column-vector).
  - Generally this will be per-token, as the zero-points are per-token.
- `scale_b` is the scale for weights; it can be per-tensor (scalar) or per-channel (row-vector).
- `azp_adj` is the precomputed zero-point adjustment term ($` \mathbf 1 \widehat B `$); it is per-channel (row-vector).
- `azp` is the zero-point (`z_a`); it is per-token (column-vector).
- `bias` is the bias; it is always per-channel (row-vector).

To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel.

The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM):
```
out = scale_a * scale_b * (Dq - azp_adj * azp) + bias
```
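As a sanity check, here is a hedged PyTorch reference of that per-token computation (a sketch with illustrative names, not the kernel itself; the benchmark above passes the same `azp_adj`, `azp`, and `bias` arguments to `ops.cutlass_scaled_mm_azp`):

```python
import torch

m, n, k = 4, 8, 16
a_q = torch.randint(-128, 128, (m, k), dtype=torch.int32)  # asymmetrically quantized activations
b_q = torch.randint(-128, 128, (k, n), dtype=torch.int32)  # symmetrically quantized weights
azp = torch.randint(-16, 16, (m, 1), dtype=torch.int32)    # per-token zero-points z_a
scale_a = torch.rand(m, 1)                                 # per-token activation scales
scale_b = torch.rand(1, n)                                 # per-channel weight scales
bias = torch.rand(1, n)                                    # per-channel bias

dq = a_q @ b_q                                             # raw int32 GEMM output Dq
azp_adj = b_q.sum(dim=0, keepdim=True)                     # 1^T \widehat B, shape (1, n)

# Rank-1 correction: broadcasting azp (m,1) against azp_adj (1,n) forms the outer product,
# so only O(m+n) extra data is needed instead of an (m,n) correction matrix.
out = scale_a * scale_b * (dq - azp * azp_adj).float() + bias

# Matches dequantizing A up front and doing the GEMM in float.
ref = (scale_a * (a_q - azp).float()) @ (scale_b * b_q.float()) + bias
assert torch.allclose(out, ref, rtol=1e-4, atol=1e-2)
```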
@ -207,6 +207,156 @@ struct VisitorRowOrScalarBroadcast {
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null
|
||||
template<
|
||||
class ThreadMap,
|
||||
class Element,
|
||||
class StrideMNL
|
||||
>
|
||||
struct VisitorRowOrZeroBroadcast {
|
||||
|
||||
// This struct has been modified to remove null_default (because it's always 0)
|
||||
struct Arguments {
|
||||
Element const* ptr_row = nullptr;
|
||||
StrideMNL dRow = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct SharedStorage {};
|
||||
|
||||
// Global load type
|
||||
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
|
||||
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
|
||||
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorRowOrZeroBroadcast() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||
: params_ptr(¶ms) { }
|
||||
|
||||
Params const* params_ptr;
|
||||
|
||||
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||
struct Callbacks : EmptyCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
Callbacks(
|
||||
GTensor&& tC_gRow,
|
||||
RTensor&& tC_rRow,
|
||||
CTensor&& tC_cRow,
|
||||
ProblemShape problem_shape,
|
||||
Params const* params_ptr
|
||||
):
|
||||
tC_gRow(cute::forward<GTensor>(tC_gRow)),
|
||||
tC_rRow(cute::forward<RTensor>(tC_rRow)),
|
||||
tC_cRow(cute::forward<CTensor>(tC_cRow)),
|
||||
n(get<1>(problem_shape)),
|
||||
params_ptr(params_ptr) { }
|
||||
|
||||
GTensor tC_gRow;
|
||||
RTensor tC_rRow;
|
||||
CTensor tC_cRow;
|
||||
Params const* params_ptr;
|
||||
int n;
|
||||
|
||||
// This function is modified from VisitorRowBroadcast
|
||||
CUTLASS_DEVICE void
|
||||
begin_epilogue() {
|
||||
clear(tC_rRow);
|
||||
auto src_v = filter(tC_gRow);
|
||||
auto coord_v = filter(tC_cRow);
|
||||
auto dst_v = filter(tC_rRow);
|
||||
|
||||
if (params_ptr->ptr_row != nullptr) {
|
||||
// In this case we are loading from a row vector and broadcasting
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src_v); ++i) {
|
||||
bool guard = get<1>(coord_v(i)) < n;
|
||||
cutlass::arch::global_load<VecType, sizeof(VecType)>(
|
||||
dst_v(i), (void const*)&src_v(i), guard);
|
||||
}
|
||||
} else {
|
||||
// In this case we are broadcasting 0
|
||||
VecType filled_vec;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < VecLength; i++) {
|
||||
reinterpret_cast<Element*>(&filled_vec)[i] = Element{0};
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src_v); ++i) {
|
||||
if (get<1>(coord_v(i)) < n) {
|
||||
dst_v(i) = filled_vec;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE auto // returns an Array
|
||||
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
|
||||
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
|
||||
Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
|
||||
return rRow_frg(column_idx);
|
||||
}
|
||||
};
|
||||
|
||||
template <class ProblemShape>
|
||||
CUTLASS_DEVICE auto
|
||||
get_callbacks(
|
||||
gemm::GemmCoord threadblock_tile_offset,
|
||||
int thread_idx,
|
||||
ProblemShape problem_shape
|
||||
) {
|
||||
Tensor mRow = make_tensor(
|
||||
make_gmem_ptr(params_ptr->ptr_row),
|
||||
problem_shape,
|
||||
params_ptr->dRow);
|
||||
|
||||
// VECTOR, FRAGMENT_COLUMN
|
||||
Tensor tC_gRow = recast<VecType>(
|
||||
ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
|
||||
)(_,_,_0{},_0{},_0{},_0{});
|
||||
Tensor tC_rRow = make_tensor_like(tC_gRow);
|
||||
|
||||
// Generate the pred tensor
|
||||
Tensor cRow = make_identity_tensor(mRow.shape());
|
||||
Tensor tC_cRow = outer_partition(
|
||||
ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
|
||||
Shape<Int<VecLength>>{},
|
||||
(_0{})
|
||||
);
|
||||
|
||||
return Callbacks<
|
||||
decltype(tC_gRow), decltype(tC_rRow),
|
||||
decltype(tC_cRow), ProblemShape>(
|
||||
cute::move(tC_gRow),
|
||||
cute::move(tC_rRow),
|
||||
cute::move(tC_cRow),
|
||||
problem_shape,
|
||||
params_ptr
|
||||
);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Column vector broadcast
|
||||
@ -217,7 +367,7 @@ template<
|
||||
>
|
||||
struct VisitorColOrScalarBroadcast {
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_col is a
|
||||
// This struct has been modified to have a bool indicating that ptr_col is a
|
||||
// scalar that must be broadcast.
|
||||
struct Arguments {
|
||||
Element const* ptr_col = nullptr;
|
||||
|
@ -50,6 +50,25 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
@ -87,6 +106,25 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
@ -139,3 +177,22 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
|
@ -73,19 +73,63 @@ struct enable_sm89_to_sm90 : Kernel {
|
||||
};
|
||||
|
||||
/*
|
||||
* This class provides the common ScaleA and ScaleB descriptors for the
|
||||
* ScaledEpilogue and ScaledEpilogueBias classes.
|
||||
* This class provides the common load descriptors for the
|
||||
* ScaledEpilogue[...] classes
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBase {
|
||||
protected:
|
||||
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
|
||||
|
||||
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
|
||||
OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
template <typename T>
|
||||
using ColOrScalarLoad =
|
||||
cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
|
||||
OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
template <typename T>
|
||||
using RowOrScalarLoad =
|
||||
cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrZeroLoad =
|
||||
cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
// This utility function constructs the arguments for the load descriptors
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
// scalar cases.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(torch::Tensor const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
|
||||
return Arguments{data_ptr, tensor.numel() != 1};
|
||||
} else {
|
||||
// it would technically work but no use case as data_ptr is never nullptr
|
||||
static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
}
|
||||
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
|
||||
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
@ -110,8 +154,8 @@ struct ScaledEpilogue
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::ScaleA;
|
||||
using ScaleB = typename SUPER::ScaleB;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
@ -131,28 +175,32 @@ struct ScaledEpilogue
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using ScaleAArgs = typename ScaleA::Arguments;
|
||||
using ScaleBArgs = typename ScaleB::Arguments;
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
|
||||
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
||||
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
||||
|
||||
typename EVTCompute0::Arguments evt0_compute_args{b_args};
|
||||
|
||||
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
|
||||
return evt_compute_args;
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
|
||||
* This bias can also be used in the per-tensor azp case, where the activation
|
||||
* zero point (azp) is used to compute an azp correction term,
|
||||
* which is folded into the bias.
|
||||
*
|
||||
* The bias tensor must be per-output channel.
|
||||
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBias
|
||||
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
protected:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::ScaleA;
|
||||
using ScaleB = typename SUPER::ScaleB;
|
||||
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowLoad<ElementD>;
|
||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
@ -164,30 +212,163 @@ struct ScaledEpilogueBias
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
|
||||
OutputTileThreadMap, ElementD, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
public:
|
||||
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
|
||||
EVTCompute0, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
using ScaleAArgs = typename ScaleA::Arguments;
|
||||
using ScaleBArgs = typename ScaleB::Arguments;
|
||||
using BiasArgs = typename Bias::Arguments;
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
||||
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
||||
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
||||
BiasArgs bias_args{static_cast<ElementD*>(bias.data_ptr()), {}};
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
typename EVTCompute0::Arguments evt0_compute_args{b_args};
|
||||
/*
|
||||
* This epilogue directly supports per-tensor azp in int32 form.
|
||||
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
|
||||
* term, which should already be multiplied with the scalar azp.
|
||||
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBiasAzp
|
||||
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
|
||||
|
||||
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
|
||||
bias_args};
|
||||
return evt_compute_args;
|
||||
// This is the full AZP term, azp * J @ B, shape (1,n)
|
||||
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
|
||||
|
||||
// Compute float(accum - azp_adj), both operands are int32_t
|
||||
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
|
||||
EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue supports per-token azp by computing and applying
|
||||
* the correction term using a rank-1 update. If the term were materialized,
|
||||
* it would require O(m*n) space, and this way it only requires O(m+n) space.
|
||||
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
|
||||
* point for each row of A.
|
||||
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBiasAzpToken
|
||||
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
|
||||
|
||||
// Per-token azp term, shape (m,1)
|
||||
using Azp = typename SUPER::template ColLoad<int32_t>;
|
||||
|
||||
// This is the AZP adjustment term, J @ B, shape (1,n)
|
||||
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
|
||||
|
||||
// Compute azp * azp_adj
|
||||
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, int32_t, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;
|
||||
|
||||
// Compute float(accum - azp*azp_adj), all operands are int32_t
|
||||
using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAcc =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
|
||||
EVTComputeAcc>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
torch::Tensor const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
|
||||
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -58,21 +58,63 @@ struct enable_sm90_or_later : Kernel {
|
||||
};
|
||||
|
||||
/*
|
||||
* This class provides the common ScaleA and ScaleB descriptors for the
|
||||
* ScaledEpilogue and ScaledEpilogueBias classes.
|
||||
* This class provides the common load descriptors for the
|
||||
* ScaledEpilogue[...] classes
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogueBase {
|
||||
protected:
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
|
||||
using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
|
||||
template <typename T>
|
||||
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
|
||||
template <typename T>
|
||||
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
// Don't want to support nullptr by default
|
||||
template <typename T, bool EnableNullPtr = false>
|
||||
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
|
||||
// Don't want to support nullptr by default
|
||||
template <typename T, bool EnableNullPtr = false>
|
||||
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
|
||||
// This utility function constructs the arguments for the load descriptors
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
// scalar cases.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(torch::Tensor const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
|
||||
return Arguments{data_ptr, tensor.numel() != 1};
|
||||
} else {
|
||||
static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
|
||||
!std::is_same_v<Descriptor, RowLoad<T, true>>);
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
}
|
||||
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
|
||||
std::is_same_v<Descriptor, RowLoad<T, true>>);
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
@ -97,8 +139,8 @@ struct ScaledEpilogue
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::ScaleA;
|
||||
using ScaleB = typename SUPER::ScaleB;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
@ -118,24 +160,32 @@ struct ScaledEpilogue
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using ScaleA_Args = typename ScaleA::Arguments;
|
||||
using ScaleB_Args = typename ScaleB::Arguments;
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
|
||||
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
||||
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
||||
|
||||
return ArgumentType{a_args, {b_args}};
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
|
||||
* This bias can also be used in the per-tensor azp case, where the activation
|
||||
* zero point (azp) is used to compute an azp correction term,
|
||||
* which is folded into the bias.
|
||||
*
|
||||
* The bias tensor must be per-output channel.
|
||||
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogueBias
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::ScaleA;
|
||||
using ScaleB = typename SUPER::ScaleB;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowLoad<ElementD>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
@ -148,27 +198,160 @@ struct ScaledEpilogueBias
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, ElementD,
|
||||
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue directly supports per-tensor azp in int32 form.
|
||||
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
|
||||
* term, which should already be multiplied with the scalar azp.
|
||||
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogueBiasAzp
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowLoad<ElementD, true>;
|
||||
|
||||
// This is the full AZP term, azp * J @ B, shape (1,n)
|
||||
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
|
||||
|
||||
// Compute float(accum - azp_adj), both operands are int32_t
|
||||
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
using ScaleA_Args = typename ScaleA::Arguments;
|
||||
using ScaleB_Args = typename ScaleB::Arguments;
|
||||
using Bias_Args = typename Bias::Arguments;
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
|
||||
|
||||
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
||||
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
||||
Bias_Args bias_args{static_cast<ElementD*>(bias.data_ptr())};
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
return ArgumentType{a_args, {b_args}, bias_args};
|
||||
/*
|
||||
* This epilogue supports per-token azp by computing and applying
|
||||
* the correction term using a rank-1 update. If the term were materialized,
|
||||
* it would require O(m*n) space, and this way it only requires O(m+n) space.
|
||||
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
|
||||
* point for each row of A.
|
||||
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogueBiasAzpToken
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowLoad<ElementD, true>;
|
||||
|
||||
// Per-token azp term, shape (m,1)
|
||||
using Azp = typename SUPER::template ColLoad<int32_t>;
|
||||
|
||||
// This is the AZP adjustment term, J @ B, shape (1,n)
|
||||
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
|
||||
|
||||
// Compute azp * azp_adj
|
||||
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, int32_t, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;
|
||||
|
||||
// Compute float(accum - azp*azp_adj), all operands are int32_t
|
||||
using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAcc =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
torch::Tensor const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
|
||||
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||
}
|
||||
};

@@ -546,4 +729,23 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
  }
}

void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (azp) {
    return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>(
        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
  } else {
    return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}

#endif

@@ -29,6 +29,40 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            c10::optional<torch::Tensor> const& bias);
#endif

void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias);

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias);
#endif

bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
  // CUTLASS FP8 kernels need at least
  // CUDA 12.0 on SM90 systems (Hopper)
@@ -45,18 +79,20 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
  return false;
}

void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
                       c10::optional<torch::Tensor> const& bias) {
  int32_t major_capability;
  int32_t minor_capability;
int32_t get_sm_version_num() {
  int32_t major_capability, minor_capability;
  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
                         0);
  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
                         0);
  int32_t version_num = major_capability * 10 + minor_capability;
  return version_num;
}

void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
                       c10::optional<torch::Tensor> const& bias) {
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
@@ -77,7 +113,7 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
  }

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));

  int32_t version_num = get_sm_version_num();
  if (version_num >= 90) {
    // Hopper

@@ -99,3 +135,64 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
    cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
  }
}

void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
                           c10::optional<torch::Tensor> const& azp,
                           c10::optional<torch::Tensor> const& bias) {
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  // bias, azp, azp_adj are all 1d
  // bias and azp_adj have n elements, azp has m elements
  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
  }
  if (azp) {
    TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
  }
  TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());

  // azp & bias types
  TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
  TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
  TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
              "currently bias dtype must match output dtype ", c.dtype());

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
  int32_t version_num = get_sm_version_num();
  if (version_num >= 90) {
    // Hopper

    // Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
    cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
#else
    cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
#endif
  } else if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
  } else if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
  } else {
    // Turing
    TORCH_CHECK(version_num >= 75);
    cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
  }
}
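The shape checks above encode the math behind azp_adj: subtracting a per-row zero point from A before the GEMM only shifts each row of the int32 accumulator by azp times the column sums of B, so the correction can be fused into the epilogue. A small self-contained check (my own toy shapes; CPU int32 matmul, as in the tests further down):

import torch

a = torch.randint(-128, 127, (4, 8), dtype=torch.int32)    # quantized A, (m, k)
b = torch.randint(-128, 127, (8, 3), dtype=torch.int32)    # quantized B, (k, n)
z = torch.tensor([[3], [1], [0], [7]], dtype=torch.int32)  # per-token zero points, (m, 1)

azp_adj = b.sum(dim=0, keepdim=True, dtype=torch.int32)    # column sums of B, (1, n)

# (A - z * 1^T) @ B == A @ B - z * (1^T @ B): the kernel computes the raw
# accumulator A @ B and applies the "- z * azp_adj" term in the epilogue.
assert torch.equal((a - z) @ b, a @ b - z * azp_adj)

# Per-tensor zero point: z is a scalar, so z * azp_adj can be precomputed on
# the host and passed as azp_adj with azp absent (the ScaledEpilogueBiasAzp path
# selected above).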

@@ -166,13 +166,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);

  // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
  // quantization.
  // quantization, as well as bias
  ops.def(
      "cutlass_scaled_mm(Tensor! out, Tensor a,"
      " Tensor b, Tensor a_scales,"
      " Tensor b_scales, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);

  // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
  // quantization.
  ops.def(
      "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
      " Tensor b, Tensor a_scales,"
      " Tensor b_scales, Tensor azp_adj,"
      " Tensor? azp, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);

  // Check if cutlass scaled_mm is supported for CUDA devices of the given
  // capability
  ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);

@@ -28,13 +28,16 @@ def to_int8(tensor: torch.Tensor):
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


def rand_int8(shape: tuple, device: str = "cuda"):
    return to_int8(torch.rand(shape, device=device) * 255 - 128)


def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: Type[torch.dtype],
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:

    output = (scale_a * (scale_b * (torch.mm(
        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
    if bias is not None:
@@ -221,6 +224,121 @@ def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
                                   use_bias)


@pytest.mark.parametrize("m", [32, 64, 128])
@pytest.mark.parametrize("n", [16, 32, 64])
@pytest.mark.parametrize("k", [64, 128, 256])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.skip
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
                                    out_dtype: torch.dtype):
    # Currently, the test is failing because folding azp into
    # 16-bit bias loses too much precision
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10

    aq_i8 = rand_int8((m, k))
    bq_i8 = rand_int8((n, k)).t()

    aq_i32 = aq_i8.to(dtype=torch.int32)
    bq_i32 = bq_i8.to(dtype=torch.int32)

    aq_f32 = aq_i8.to(dtype=torch.float32)
    bq_f32 = bq_i8.to(dtype=torch.float32)

    b_dq = scale_b * bq_f32

    azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5
    azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
    azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding

    a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
    assert torch.allclose(a_dq, scale_a * aq_f32 + azp_a)

    baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)

    J = torch.ones((1, k), device="cuda", dtype=torch.float32)
    azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
    assert azp_bias.shape == (1, n)
    assert azp_bias[0, :].shape == (n, )

    baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
        (aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
            dtype=out_dtype, device='cuda')

    out = ops.cutlass_scaled_mm(aq_i8,
                                bq_i8,
                                scale_a,
                                scale_b,
                                out_dtype=out_dtype,
                                bias=azp_bias[0, :])
    assert torch.allclose(out, baseline_dq, rtol=1e-2, atol=1e0)
    assert torch.allclose(out, baseline_q, rtol=1e-2, atol=1e0)


@pytest.mark.parametrize("m", [32, 64, 128])
@pytest.mark.parametrize("n", [16, 32, 64])
@pytest.mark.parametrize("k", [64, 128, 256])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("azp_per_token", [True, False])
def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
                          use_bias: bool, azp_per_token: bool):
    m_azp = m if azp_per_token else 1
    scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10

    aq_i8 = rand_int8((m, k))
    aq_i32 = aq_i8.to(dtype=torch.int32)
    aq_f32 = aq_i8.to(dtype=torch.float32)

    bq_i8 = rand_int8((n, k)).t()
    bq_i32 = bq_i8.to(dtype=torch.int32)
    bq_f32 = bq_i8.to(dtype=torch.float32)
    b_dq = scale_b * bq_f32

    azp_a = torch.rand(
        (m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
    azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
    azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding

    a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
    assert torch.allclose(a_dq, scale_a * aq_f32 - azp_a, rtol=1e-4, atol=1e-3)

    if use_bias:
        bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
    else:
        bias = torch.zeros((1, n), device="cuda", dtype=out_dtype)

    baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)

    # int32 mm not supported on CUDA
    a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device='cpu')
    cq = (a_noazp_i32_cpu @ bq_i32.to(device='cpu')).to(device='cuda')
    baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype)

    # Hadamard is just the sum of the cols
    azp_adj_i32 = bq_i32.sum(dim=0, keepdim=True, dtype=torch.int32)
    azp_i32 = azp_aq_i8.to(dtype=torch.int32)
    func_bias = bias if use_bias else None

    if azp_per_token:
        out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
                                        out_dtype, azp_adj_i32, azp_i32,
                                        func_bias)
    else:
        azp_with_adj_i32 = azp_i32 * azp_adj_i32
        out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
                                        out_dtype, azp_with_adj_i32, None,
                                        func_bias)

    # bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
    # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
    rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
    atol = 1e-3
    assert torch.allclose(out, baseline_dq, rtol=rtol, atol=atol)
    assert torch.allclose(out, baseline_q, rtol=rtol, atol=atol)


# Test working with a subset of A and B
def test_cutlass_subset():
    big_m, big_n, big_k = 1024, 1024, 1024

@@ -241,6 +241,8 @@ def cutlass_scaled_mm(a: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or bias.shape[0] == b.shape[
        1] and bias.dtype == out_dtype

    m = a.shape[0]
    n = b.shape[1]
@@ -251,6 +253,28 @@ def cutlass_scaled_mm(a: torch.Tensor,
    return out


def cutlass_scaled_mm_azp(a: torch.Tensor,
                          b: torch.Tensor,
                          scale_a: torch.Tensor,
                          scale_b: torch.Tensor,
                          out_dtype: torch.dtype,
                          azp_adj: torch.Tensor,
                          azp: Optional[torch.Tensor] = None,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or bias.numel(
    ) == b.shape[1] and bias.dtype == out_dtype

    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj,
                                       azp, bias)
    return out
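For reference, a usage sketch of this wrapper (my own example shapes and values; assumes a CUDA device and the compiled _C extension, mirroring test_cutlass_int8_azp above): per-token AZP passes azp and azp_adj separately, while per-tensor AZP folds the scalar zero point into azp_adj and passes azp=None.

import torch

a_q = torch.randint(-128, 127, (32, 64), dtype=torch.int8, device="cuda")
b_q = torch.randint(-128, 127, (16, 64), dtype=torch.int8, device="cuda").t()
scale_a = torch.rand((32, 1), dtype=torch.float32, device="cuda") / 10
scale_b = torch.rand((1, 16), dtype=torch.float32, device="cuda") / 10
azp = torch.randint(-5, 5, (32, 1), dtype=torch.int32, device="cuda")
azp_adj = b_q.to(torch.int32).sum(dim=0, keepdim=True, dtype=torch.int32)

# Per-token zero points: azp has one entry per row of a_q.
out_token = cutlass_scaled_mm_azp(a_q, b_q, scale_a, scale_b, torch.bfloat16,
                                  azp_adj, azp=azp)

# Per-tensor zero point: fold the scalar into azp_adj and omit azp.
azp_scalar = torch.tensor([3], dtype=torch.int32, device="cuda")
out_tensor = cutlass_scaled_mm_azp(a_q, b_q, scale_a, scale_b, torch.bfloat16,
                                   azp_scalar * azp_adj, azp=None)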


# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
@@ -572,7 +596,7 @@ for k, v in names_and_values.items():
    if isinstance(v, fn_type) \
            and v.__code__.co_filename == __file__ \
            and any(arg is torch.Tensor or arg == "torch.Tensor"
                    for arg in v.__annotations__.values()):
        names_and_values_to_update[k] = hint_on_error(v)

names_and_values.update(names_and_values_to_update)