#ifndef DNNL_HELPER_HPP
#define DNNL_HELPER_HPP

#include <cstdint>
#include <type_traits>

#include <c10/util/BFloat16.h>
#include <c10/util/Exception.h>
#include <c10/util/Half.h>

#include "oneapi/dnnl/dnnl.hpp"

namespace {
// Compile-time mapping from C++ element types to oneDNN memory data types;
// unspecialized types fall back to data_type::undef.
template <typename T>
struct DNNLType {
  static constexpr dnnl::memory::data_type type =
      dnnl::memory::data_type::undef;
};

template <>
struct DNNLType<int8_t> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
};

template <>
struct DNNLType<int32_t> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
};

template <>
struct DNNLType<float> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
};

template <>
struct DNNLType<c10::BFloat16> {
  static constexpr dnnl::memory::data_type type =
      dnnl::memory::data_type::bf16;
};

template <>
struct DNNLType<c10::Half> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
};

template <typename T>
constexpr inline dnnl::memory::data_type get_dnnl_type() {
  return DNNLType<std::decay_t<T>>::type;
}
}  // namespace
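
// Illustrative sketch: get_dnnl_type<T>() is constexpr, so the mapping can be
// checked at compile time, e.g.
//   static_assert(get_dnnl_type<float>() == dnnl::memory::data_type::f32);
//   static_assert(get_dnnl_type<const c10::Half&>() ==
//                 dnnl::memory::data_type::f16);  // std::decay_t drops cv/ref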

// Helper around oneDNN (DNNL) matmul primitives. When InputNoScale is true,
// no activation (DNNL_ARG_SRC) scales are attached at execution time;
// otherwise a single per-tensor a_scale is applied.
template <bool InputNoScale>
class DNNLPrimitiveHelper {
 public:
  // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias)
  // A: [M, K], row-major
  // B: [K, N], column-major
  // C: [M, N], row-major
  // bias: [N], row-major, optional
  // a_scales: [MS]
  // b_scales: [NS]
  // Note: due to a limitation of oneDNN
  // (https://github.com/oneapi-src/oneDNN/issues/1636), quantized bias is
  // not supported.
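  //
  // Illustrative call (buffer names, extents, and scale values below are
  // placeholders, not part of this API):
  //   std::vector<int8_t> A(M * K), B(K * N);   // B laid out column-major
  //   std::vector<float> C(M * N);
  //   float a_scale = 0.02f, b_scale = 0.01f;   // per-tensor scales
  //   DNNLPrimitiveHelper<false>::gemm_s8s8_jit<float, float>(
  //       A.data(), B.data(), C.data(), /*bias=*/nullptr, M, N, K,
  //       &a_scale, &b_scale, /*MS=*/1, /*NS=*/1);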
  template <typename OutputT, typename BiasT>
  static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
                            const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
                            dnnl_dim_t K, const float* a_scales,
                            const float* b_scales, dnnl_dim_t MS,
                            dnnl_dim_t NS) {
    auto&& OutputType = get_dnnl_type<OutputT>();
    auto&& BiasType = get_dnnl_type<BiasT>();

    // Logical shapes with explicit strides: A and C are row-major, B is
    // column-major ({1, K}).
    dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1});
    dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K});
    dnnl::memory::desc c_md({M, N}, OutputType, {N, 1});

    dnnl::primitive_attr attr;
    if constexpr (!InputNoScale) {
      if (MS == 1) {
        // per-tensor
        attr.set_scales_mask(DNNL_ARG_SRC, 0);
      } else {
        // per-token
        TORCH_CHECK(false, "per-token quantization is unsupported.");
      }
    }

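    // oneDNN scale-mask semantics: each set bit selects a dimension along
    // which the scales vary; mask 0 means one common scale. For the 2-D
    // [K, N] weight tensor, mask 2 (bit 1) gives one scale per output column.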
    if (NS == 1) {
      // per-tensor
      attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
    } else {
      // per-channel
      attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
    }

    dnnl::matmul::primitive_desc matmul_pd;
    if (bias) {
      dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
      matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
                                               bias_md, c_md, attr);
    } else {
      matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
                                               c_md, attr);
    }
    dnnl::matmul matmul(matmul_pd);

    auto& engine = default_engine();

    // Wrap the caller-provided buffers as oneDNN memory objects.
    dnnl::memory a_m(a_md, engine, (void*)a);
    dnnl::memory b_m(b_md, engine, (void*)b);
    dnnl::memory c_m(c_md, engine, (void*)c);
    dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine,
                            (void*)a_scales);
    dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine,
                            (void*)b_scales);

    auto& stream = default_stream();
    if constexpr (InputNoScale) {
      if (bias) {
        dnnl::memory::desc bias_md({N}, BiasType, {1});
        dnnl::memory bias_m(bias_md, engine, (void*)bias);
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, a_m},
                        {DNNL_ARG_WEIGHTS, b_m},
                        {DNNL_ARG_BIAS, bias_m},
                        {DNNL_ARG_DST, c_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      } else {
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, a_m},
                        {DNNL_ARG_WEIGHTS, b_m},
                        {DNNL_ARG_DST, c_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      }
    } else {
      if (bias) {
        dnnl::memory::desc bias_md({N}, BiasType, {1});
        dnnl::memory bias_m(bias_md, engine, (void*)bias);
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, a_m},
                        {DNNL_ARG_WEIGHTS, b_m},
                        {DNNL_ARG_BIAS, bias_m},
                        {DNNL_ARG_DST, c_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      } else {
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, a_m},
                        {DNNL_ARG_WEIGHTS, b_m},
                        {DNNL_ARG_DST, c_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      }
    }
    // Block until the submitted primitive has finished executing.
    stream.wait();
  }

 private:
  // Process-wide CPU engine and stream, lazily constructed on first use
  // (function-local statics are initialized in a thread-safe manner).
  static dnnl::engine& default_engine() {
    static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
    return engine;
  }

  static dnnl::stream& default_stream() {
    static dnnl::stream stream(default_engine());
    return stream;
  }
};

#endif  // DNNL_HELPER_HPP