[Kernel] Add per-tensor and per-token AZP epilogues (#5941)

Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Luka Govedič 2024-08-06 14:17:08 -04:00 committed by GitHub
parent 5c60c8c423
commit 8d59dbb000
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 1175 additions and 153 deletions

View File

@ -32,7 +32,6 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor:
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5
@ -44,59 +43,18 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
raise ValueError("unsupported dtype")
# impl
def pytorch_mm_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype) -> torch.Tensor:
return torch.mm(a, b)
def pytorch_fp8_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype) -> torch.Tensor:
return torch._scaled_mm(a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype)
def pytorch_fp8_impl_fast_accum(a: torch.Tensor, b: torch.Tensor,
scale_a: torch.Tensor, scale_b: torch.Tensor,
out_dtype: torch.dtype) -> torch.Tensor:
return torch._scaled_mm(a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
use_fast_accum=True)
def cutlass_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype) -> torch.Tensor:
return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)
# bench
def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor, out_dtype: torch.dtype, label: str,
sub_label: str, fn: Callable, description: str) -> TMeasurement:
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
**kwargs) -> TMeasurement:
min_run_time = 1
globals = {
"a": a,
"b": b,
"scale_a": scale_a,
"scale_b": scale_b,
"out_dtype": out_dtype,
"args": args,
"kwargs": kwargs,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(a, b, scale_a, scale_b, out_dtype)",
stmt="fn(*args, **kwargs)",
globals=globals,
label=label,
sub_label=sub_label,
@ -110,26 +68,58 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
a, b = make_rand_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
azp = torch.zeros((m, ), device="cuda", dtype=torch.int32)
azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32)
timers = []
# pytorch impl - bfloat16
timers.append(
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
torch.bfloat16, label, sub_label, pytorch_mm_impl,
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm, a.to(dtype=torch.bfloat16),
b.to(dtype=torch.bfloat16)))
# pytorch impl - float16
timers.append(
bench_fn(a.to(dtype=torch.float16, device="cuda"),
b.to(dtype=torch.float16, device="cuda"), scale_a, scale_b,
torch.float16, label, sub_label, pytorch_mm_impl,
"pytorch_fp16_fp16_fp16_matmul-no-scales"))
bench_fn(label, sub_label,
"pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
a.to(dtype=torch.float16), b.to(dtype=torch.float16)))
# cutlass impl
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
cutlass_impl, "cutlass_i8_i8_bf16_scaled_mm"))
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
torch.bfloat16))
# cutlass with bias
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
bias))
# cutlass with azp per-tensor
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp",
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
torch.bfloat16, azp_adj))
# cutlass with azp per-tensor + bias
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_bias",
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
torch.bfloat16, azp_adj, None, bias))
# cutlass with azp per-token
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt",
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
torch.bfloat16, azp_adj, azp))
# cutlass with azp per-token + bias
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias",
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
torch.bfloat16, azp_adj, azp, bias))
return timers
@ -140,46 +130,88 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
timers = []
# pytorch impl w. bf16
timers.append(
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
torch.bfloat16, label, sub_label, pytorch_mm_impl,
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda")))
# pytorch impl: bf16 output, without fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm"))
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16))
# pytorch impl: bf16 output, with fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
pytorch_fp8_impl_fast_accum,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"))
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16,
use_fast_accum=True))
# pytorch impl: fp16 output, without fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm"))
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16))
# pytorch impl: fp16 output, with fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
pytorch_fp8_impl_fast_accum,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"))
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16,
use_fast_accum=True))
# cutlass impl: bf16 output
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm"))
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
torch.bfloat16))
# cutlass impl: fp16 output
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm"))
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16))
# cutlass impl: bf16 output, with bias
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm_bias",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
bias))
# cutlass impl: fp16 output, with bias
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm_bias",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16,
bias.to(dtype=torch.float16)))
return timers
@ -200,7 +232,6 @@ def print_timers(timers: Iterable[TMeasurement]):
def run(dtype: torch.dtype,
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
results = []
for m, k, n in MKNs:
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
@ -216,7 +247,6 @@ def make_output(data: Iterable[TMeasurement],
MKNs: Iterable[Tuple[int, int, int]],
base_description: str,
timestamp=None):
print(f"== All Results {base_description} ====")
print_timers(data)
@ -251,7 +281,6 @@ def run_range_bench(args):
def run_model_bench(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")

View File

@ -128,6 +128,14 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias);
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
torch::Tensor const& b_q_weight,
torch::Tensor const& s_tok,

View File

@ -0,0 +1,147 @@
# CUTLASS Epilogues
## Introduction
This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs.
Currently, we only support symmetric quantization for weights,
and symmetric and asymmetric quantization for activations.
Both can be quantized per-tensor or per-channel (weights) / per-token (activations).
There are 4 epilogues:
1. ScaledEpilogue: symmetric quantization for activations, no bias.
1. ScaledEpilogueBias: symmetric quantization for activations, supports bias.
1. ScaledEpilogueAzp: asymmetric per-tensor quantization for activations, supports bias.
1. ScaledEpilogueAzpPerToken: asymmetric per-token quantization for activations, supports bias.
We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size.
Instead, if no bias is passed, the epilogue will use 0 as the bias.
That induces a redundant addition operation (and runtime check), but the performance impact is minor.
## Underlying Linear Algebra
More details available in the [Activation Quantization RFC](https://github.com/vllm-project/vllm/issues/3975).
If $` \widehat X `$ is the quantized $` X `$, our matrices become the following
```math
A = s_a (\widehat A - J_a z_a)
```
```math
B = s_b \widehat B
```
```math
D = A B + C
```
```math
D = s_a s_b \widehat D + C
```
Here, D is the output of the GEMM, and C is the bias.
A is the activations and supports asymmetric quantization,
and B is the weights and only supports symmetric quantization.
$ s_a $ and $s_b$ are the scales for activations and weights, respectively.
$ z_a $ is the zero-point for activations, and $ J_a $ is the matrix of all ones with dimensions of A.
Additional epilogues would be required to support asymmetric quantization for weights.
Expanding further, we can calculate $` \widehat D `$ as follows:
```math
A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B
```
```math
A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right)
```
```math
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
```
Note that $` \widehat A \widehat B `$ is the raw output of the GEMM,
and $` J_a \widehat B `$ is known ahead of time.
Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of column sums of $` \widehat B `$.
## Epilogues
### ScaledEpilogue
This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$.
The output of the GEMM is:
```math
\widehat D = \widehat A \widehat B
```
```math
D = s_a s_b \widehat D
```
```math
D = s_a s_b \widehat A \widehat B
```
Epilogue parameters:
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
### ScaledEpilogueBias
This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$.
The output of the GEMM is:
```math
\widehat D = \widehat A \widehat B
```
```math
D = s_a s_b \widehat D + C
```
```math
D = s_a s_b \widehat A \widehat B + C
```
Epilogue parameters:
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
- `bias` is the bias, is always per-channel (row-vector).
### ScaledEpilogueAzp
This epilogue computes the asymmetric per-tensor quantization for activations with bias.
The output of the GEMM is:
```math
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
```
```math
D = s_a s_b \widehat D + C
```
```math
D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C
```
Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 B `$.
That is precomputed and stored in `azp_with_adj` as a row-vector.
Epilogue parameters:
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
- Generally this will be per-tensor as the zero-points are per-tensor.
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
- `azp_with_adj` is the precomputed zero-point term ($` z_a J_a \widehat B `$), is per-channel (row-vector).
- `bias` is the bias, is always per-channel (row-vector).
To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel.
### ScaledEpilogueAzpPerToken
This epilogue computes the asymmetric per-token quantization for activations with bias.
The output of the GEMM is the same as above, but the $` z_a `$ is a column-vector.
That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$.
Epilogue parameters:
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
- Generally this will be per-token as the zero-points are per-token.
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
- `azp_adj` is the precomputed zero-point adjustment term ($` \mathbf 1 \widehat B `$), is per-channel (row-vector).
- `azp` is the zero-point (`z_a`), is per-token (column-vector).
- `bias` is the bias, is always per-channel (row-vector).
To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel.
The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM):
```
out = scale_a * scale_b * (Dq - azp_adj * azp) + bias
```

View File

@ -207,6 +207,156 @@ struct VisitorRowOrScalarBroadcast {
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null
template<
class ThreadMap,
class Element,
class StrideMNL
>
struct VisitorRowOrZeroBroadcast {
// This struct has been modified to remove null_default (because it's always 0)
struct Arguments {
Element const* ptr_row = nullptr;
StrideMNL dRow = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
struct SharedStorage {};
// Global load type
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
CUTLASS_HOST_DEVICE
VisitorRowOrZeroBroadcast() { }
CUTLASS_HOST_DEVICE
VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage)
: params_ptr(&params) { }
Params const* params_ptr;
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
struct Callbacks : EmptyCallbacks {
CUTLASS_DEVICE
Callbacks(
GTensor&& tC_gRow,
RTensor&& tC_rRow,
CTensor&& tC_cRow,
ProblemShape problem_shape,
Params const* params_ptr
):
tC_gRow(cute::forward<GTensor>(tC_gRow)),
tC_rRow(cute::forward<RTensor>(tC_rRow)),
tC_cRow(cute::forward<CTensor>(tC_cRow)),
n(get<1>(problem_shape)),
params_ptr(params_ptr) { }
GTensor tC_gRow;
RTensor tC_rRow;
CTensor tC_cRow;
Params const* params_ptr;
int n;
// This function is modified from VisitorRowBroadcast
CUTLASS_DEVICE void
begin_epilogue() {
clear(tC_rRow);
auto src_v = filter(tC_gRow);
auto coord_v = filter(tC_cRow);
auto dst_v = filter(tC_rRow);
if (params_ptr->ptr_row != nullptr) {
// In this case we are loading from a row vector and broadcasting
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
bool guard = get<1>(coord_v(i)) < n;
cutlass::arch::global_load<VecType, sizeof(VecType)>(
dst_v(i), (void const*)&src_v(i), guard);
}
} else {
// In this case we are broadcasting 0
VecType filled_vec;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < VecLength; i++) {
reinterpret_cast<Element*>(&filled_vec)[i] = Element{0};
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
if (get<1>(coord_v(i)) < n) {
dst_v(i) = filled_vec;
}
}
}
}
template <class ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE auto // returns an Array
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
return rRow_frg(column_idx);
}
};
template <class ProblemShape>
CUTLASS_DEVICE auto
get_callbacks(
gemm::GemmCoord threadblock_tile_offset,
int thread_idx,
ProblemShape problem_shape
) {
Tensor mRow = make_tensor(
make_gmem_ptr(params_ptr->ptr_row),
problem_shape,
params_ptr->dRow);
// VECTOR, FRAGMENT_COLUMN
Tensor tC_gRow = recast<VecType>(
ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
)(_,_,_0{},_0{},_0{},_0{});
Tensor tC_rRow = make_tensor_like(tC_gRow);
// Generate the pred tensor
Tensor cRow = make_identity_tensor(mRow.shape());
Tensor tC_cRow = outer_partition(
ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
Shape<Int<VecLength>>{},
(_0{})
);
return Callbacks<
decltype(tC_gRow), decltype(tC_rRow),
decltype(tC_cRow), ProblemShape>(
cute::move(tC_gRow),
cute::move(tC_rRow),
cute::move(tC_cRow),
problem_shape,
params_ptr
);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Column vector broadcast
@ -217,7 +367,7 @@ template<
>
struct VisitorColOrScalarBroadcast {
// This struct has been modified to have a bool indicating that ptr_col is a
// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast.
struct Arguments {
Element const* ptr_col = nullptr;

View File

@ -50,6 +50,25 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
}
}
void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) {
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
@ -87,6 +106,25 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
}
}
void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) {
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
@ -139,3 +177,22 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) {
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}

View File

@ -73,19 +73,63 @@ struct enable_sm89_to_sm90 : Kernel {
};
/*
* This class provides the common ScaleA and ScaleB descriptors for the
* ScaledEpilogue and ScaledEpilogueBias classes.
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using ColOrScalarLoad =
cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using RowOrScalarLoad =
cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using RowOrZeroLoad =
cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
} else {
// it would technically work but no use case as data_ptr is never nullptr
static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
return Arguments{data_ptr};
}
};
/*
@ -110,8 +154,8 @@ struct ScaledEpilogue
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::ScaleA;
using ScaleB = typename SUPER::ScaleB;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
@ -131,28 +175,32 @@ struct ScaledEpilogue
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using ScaleAArgs = typename ScaleA::Arguments;
using ScaleBArgs = typename ScaleB::Arguments;
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
typename EVTCompute0::Arguments evt0_compute_args{b_args};
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
return evt_compute_args;
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
protected:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::ScaleA;
using ScaleB = typename SUPER::ScaleB;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
@ -164,30 +212,163 @@ struct ScaledEpilogueBias
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, ElementD, Stride<Int<0>, Int<1>, Int<0>>>;
public:
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
using ScaleAArgs = typename ScaleA::Arguments;
using ScaleBArgs = typename ScaleB::Arguments;
using BiasArgs = typename Bias::Arguments;
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
BiasArgs bias_args{static_cast<ElementD*>(bias.data_ptr()), {}};
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args, bias_args};
}
};
typename EVTCompute0::Arguments evt0_compute_args{b_args};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
bias_args};
return evt_compute_args;
// This is the full AZP term, azp * J @ B, shape (1,n)
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
// Compute float(accum - azp_adj), both operands are int32_t
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
EVTComputeAzp>;
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
return ArgumentType{a_args, evt_scale_b_args, bias_args};
}
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
// Per-token azp term, shape (m,1)
using Azp = typename SUPER::template ColLoad<int32_t>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
// Compute azp * azp_adj
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, int32_t, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAcc =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
EVTComputeAcc>;
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
c10::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
return ArgumentType{a_args, evt_scale_b_args, bias_args};
}
};

View File

@ -58,21 +58,63 @@ struct enable_sm90_or_later : Kernel {
};
/*
* This class provides the common ScaleA and ScaleB descriptors for the
* ScaledEpilogue and ScaledEpilogueBias classes.
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
template <typename T>
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<1>, Int<0>, Int<0>>>;
using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
template <typename T>
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<0>, Int<1>, Int<0>>>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
} else {
static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
!std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
};
/*
@ -97,8 +139,8 @@ struct ScaledEpilogue
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::ScaleA;
using ScaleB = typename SUPER::ScaleB;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
@ -118,24 +160,32 @@ struct ScaledEpilogue
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using ScaleA_Args = typename ScaleA::Arguments;
using ScaleB_Args = typename ScaleB::Arguments;
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
return ArgumentType{a_args, {b_args}};
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBias
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::ScaleA;
using ScaleB = typename SUPER::ScaleB;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
@ -148,27 +198,160 @@ struct ScaledEpilogueBias
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, ElementD,
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args, bias_args};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzp
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD, true>;
// This is the full AZP term, azp * J @ B, shape (1,n)
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
// Compute float(accum - azp_adj), both operands are int32_t
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
using ScaleA_Args = typename ScaleA::Arguments;
using ScaleB_Args = typename ScaleB::Arguments;
using Bias_Args = typename Bias::Arguments;
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
Bias_Args bias_args{static_cast<ElementD*>(bias.data_ptr())};
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
return ArgumentType{a_args, evt_scale_b_args, bias_args};
}
};
return ArgumentType{a_args, {b_args}, bias_args};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzpToken
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD, true>;
// Per-token azp term, shape (m,1)
using Azp = typename SUPER::template ColLoad<int32_t>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
// Compute azp * azp_adj
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, int32_t, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAcc =
cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
c10::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
return ArgumentType{a_args, evt_scale_b_args, bias_args};
}
};
@ -546,4 +729,23 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
}
}
void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) {
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
#endif

View File

@ -29,6 +29,40 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
c10::optional<torch::Tensor> const& bias);
#endif
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias);
#endif
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
// CUTLASS FP8 kernels need at least
// CUDA 12.0 on SM90 systems (Hopper)
@ -45,18 +79,20 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
return false;
}
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) {
int32_t major_capability;
int32_t minor_capability;
int32_t get_sm_version_num() {
int32_t major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
0);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
0);
int32_t version_num = major_capability * 10 + minor_capability;
return version_num;
}
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) {
// Checks for conformality
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
@ -77,7 +113,7 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
}
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
if (version_num >= 90) {
// Hopper
@ -99,3 +135,64 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
}
}
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias) {
// Checks for conformality
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1));
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
TORCH_CHECK(b.stride(0) == 1); // Column-major
TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
// bias, azp, azp_adj are all 1d
// bias and azp_adj have n elements, azp has m elements
if (bias) {
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
}
if (azp) {
TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
}
TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());
// azp & bias types
TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
"currently bias dtype must match output dtype ", c.dtype());
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
if (version_num >= 90) {
// Hopper
// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
#else
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
#endif
} else if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
} else if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
} else {
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
}
}

View File

@ -166,13 +166,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization.
// quantization, as well as bias
ops.def(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
// Check if cutlass scaled_mm is supported for CUDA devices of the given
// capability
ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);

View File

@ -28,13 +28,16 @@ def to_int8(tensor: torch.Tensor):
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
def rand_int8(shape: tuple, device: str = "cuda"):
return to_int8(torch.rand(shape, device=device) * 255 - 128)
def baseline_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
output = (scale_a * (scale_b * (torch.mm(
a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
if bias is not None:
@ -221,6 +224,121 @@ def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
use_bias)
@pytest.mark.parametrize("m", [32, 64, 128])
@pytest.mark.parametrize("n", [16, 32, 64])
@pytest.mark.parametrize("k", [64, 128, 256])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.skip
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
out_dtype: torch.dtype):
# Currently, the test is failing because folding azp into
# 16-bit bias loses too much precision
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
aq_i8 = rand_int8((m, k))
bq_i8 = rand_int8((n, k)).t()
aq_i32 = aq_i8.to(dtype=torch.int32)
bq_i32 = bq_i8.to(dtype=torch.int32)
aq_f32 = aq_i8.to(dtype=torch.float32)
bq_f32 = bq_i8.to(dtype=torch.float32)
b_dq = scale_b * bq_f32
azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
assert torch.allclose(a_dq, scale_a * aq_f32 + azp_a)
baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)
J = torch.ones((1, k), device="cuda", dtype=torch.float32)
azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
assert azp_bias.shape == (1, n)
assert azp_bias[0, :].shape == (n, )
baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
(aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
dtype=out_dtype, device='cuda')
out = ops.cutlass_scaled_mm(aq_i8,
bq_i8,
scale_a,
scale_b,
out_dtype=out_dtype,
bias=azp_bias[0, :])
assert torch.allclose(out, baseline_dq, rtol=1e-2, atol=1e0)
assert torch.allclose(out, baseline_q, rtol=1e-2, atol=1e0)
@pytest.mark.parametrize("m", [32, 64, 128])
@pytest.mark.parametrize("n", [16, 32, 64])
@pytest.mark.parametrize("k", [64, 128, 256])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("azp_per_token", [True, False])
def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
use_bias: bool, azp_per_token: bool):
m_azp = m if azp_per_token else 1
scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
aq_i8 = rand_int8((m, k))
aq_i32 = aq_i8.to(dtype=torch.int32)
aq_f32 = aq_i8.to(dtype=torch.float32)
bq_i8 = rand_int8((n, k)).t()
bq_i32 = bq_i8.to(dtype=torch.int32)
bq_f32 = bq_i8.to(dtype=torch.float32)
b_dq = scale_b * bq_f32
azp_a = torch.rand(
(m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
assert torch.allclose(a_dq, scale_a * aq_f32 - azp_a, rtol=1e-4, atol=1e-3)
if use_bias:
bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
else:
bias = torch.zeros((1, n), device="cuda", dtype=out_dtype)
baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)
# int32 mm not supported on CUDA
a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device='cpu')
cq = (a_noazp_i32_cpu @ bq_i32.to(device='cpu')).to(device='cuda')
baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype)
# Hadamard is just the sum of the cols
azp_adj_i32 = bq_i32.sum(dim=0, keepdim=True, dtype=torch.int32)
azp_i32 = azp_aq_i8.to(dtype=torch.int32)
func_bias = bias if use_bias else None
if azp_per_token:
out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
out_dtype, azp_adj_i32, azp_i32,
func_bias)
else:
azp_with_adj_i32 = azp_i32 * azp_adj_i32
out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
out_dtype, azp_with_adj_i32, None,
func_bias)
# bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
atol = 1e-3
assert torch.allclose(out, baseline_dq, rtol=rtol, atol=atol)
assert torch.allclose(out, baseline_q, rtol=rtol, atol=atol)
# Test working with a subset of A and B
def test_cutlass_subset():
big_m, big_n, big_k = 1024, 1024, 1024

View File

@ -241,6 +241,8 @@ def cutlass_scaled_mm(a: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
assert bias is None or bias.shape[0] == b.shape[
1] and bias.dtype == out_dtype
m = a.shape[0]
n = b.shape[1]
@ -251,6 +253,28 @@ def cutlass_scaled_mm(a: torch.Tensor,
return out
def cutlass_scaled_mm_azp(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
azp_adj: torch.Tensor,
azp: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
assert bias is None or bias.numel(
) == b.shape[1] and bias.dtype == out_dtype
m = a.shape[0]
n = b.shape[1]
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj,
azp, bias)
return out
# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
@ -572,7 +596,7 @@ for k, v in names_and_values.items():
if isinstance(v, fn_type) \
and v.__code__.co_filename == __file__ \
and any(arg is torch.Tensor or arg == "torch.Tensor"
for arg in v.__annotations__.values()):
for arg in v.__annotations__.values()):
names_and_values_to_update[k] = hint_on_error(v)
names_and_values.update(names_and_values_to_update)