[Misc] Use scalar type to dispatch to different gptq_marlin kernels (#7323)

commit 6aa33cb2dd
parent 1137f343aa
@@ -20,7 +20,7 @@ namespace vllm {
 //
 class ScalarType {
  public:
-  enum NanRepr : int64_t {
+  enum NanRepr : uint8_t {
     NAN_NONE = 0,                // nans are not supported
     NAN_IEEE_754 = 1,            // nans are: exp all 1s, mantissa not all 0s
     NAN_EXTD_RANGE_MAX_MIN = 2,  // nans are: exp all 1s, mantissa all 1s
@@ -28,33 +28,33 @@ class ScalarType {
     NAN_REPR_ID_MAX
   };

-  constexpr ScalarType(bool signed_, int64_t exponent, int64_t mantissa,
-                       int64_t bias, bool finite_values_only = false,
+  constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
+                       int32_t bias, bool finite_values_only = false,
                        NanRepr nan_repr = NAN_IEEE_754)
       : exponent(exponent),
         mantissa(mantissa),
-        bias(bias),
         signed_(signed_),
+        bias(bias),
         finite_values_only(finite_values_only),
         nan_repr(nan_repr){};

-  static constexpr ScalarType int_(int64_t size_bits, int64_t bias = 0) {
-    return ScalarType(true, 0, size_bits - 1, bias);
+  static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
+    return ScalarType(0, size_bits - 1, true, bias);
   }

-  static constexpr ScalarType uint(int64_t size_bits, int64_t bias = 0) {
-    return ScalarType(false, 0, size_bits, bias);
+  static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
+    return ScalarType(0, size_bits, false, bias);
   }

   // IEEE 754 compliant floating point type
-  static constexpr ScalarType float_IEEE754(int64_t exponent,
-                                            int64_t mantissa) {
+  static constexpr ScalarType float_IEEE754(uint8_t exponent,
+                                            uint8_t mantissa) {
     TORCH_CHECK(mantissa > 0 && exponent > 0);
-    return ScalarType(true, exponent, mantissa, 0, false, NAN_IEEE_754);
+    return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
   }

   // IEEE 754 non-compliant floating point type
-  static constexpr ScalarType float_(int64_t exponent, int64_t mantissa,
+  static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
                                      bool finite_values_only,
                                      NanRepr nan_repr) {
     TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
@@ -62,36 +62,121 @@ class ScalarType {
     TORCH_CHECK(nan_repr != NAN_IEEE_754,
                 "use `float_IEEE754` constructor for floating point types that "
                 "follow IEEE 754 conventions");
-    return ScalarType(true, exponent, mantissa, 0, finite_values_only,
+    return ScalarType(exponent, mantissa, true, 0, finite_values_only,
                       nan_repr);
   }

-  int64_t const exponent;  // size of the exponent field (0 for integer types)
-  int64_t const mantissa;  // size of the mantissa field (size of the integer
+  uint8_t const exponent;  // size of the exponent field (0 for integer types)
+  uint8_t const mantissa;  // size of the mantissa field (size of the integer
                            // excluding the sign bit for integer types)
-  int64_t const bias;  // stored values equal value + bias,
-                       // used for quantized type
   bool const signed_;  // flag if the type supports negative numbers (i.e. has a
                        // sign bit)
+  int32_t const bias;  // stored values equal value + bias,
+                       // used for quantized type

   // Extra Floating point info
   bool const finite_values_only;  // i.e. no +/-inf if true
   NanRepr const nan_repr;         // how NaNs are represented
                                   // (not applicable for integer types)

-  int64_t size_bits() const { return mantissa + exponent + is_signed(); }
-  bool is_signed() const { return signed_; }
-  bool is_integer() const { return exponent == 0; }
-  bool is_floating_point() const { return exponent > 0; }
-  bool is_ieee_754() const {
+  using Id = int64_t;
+
+ private:
+  // Field size in id
+  template <typename T_>
+  static constexpr size_t member_id_field_width() {
+    using T = std::decay_t<T_>;
+    return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
+  }
+
+  template <typename Fn, typename Init, typename Member, typename... Rest>
+  static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
+                                              Rest... rest) {
+    auto new_val = f(val, member);
+    if constexpr (sizeof...(rest) > 0) {
+      return reduce_members_helper(f, new_val, rest...);
+    } else {
+      return new_val;
+    };
+  }
+
+  template <typename Fn, typename Init>
+  constexpr auto reduce_members(Fn f, Init init) const {
+    // Should be in constructor order for `from_id`
+    return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
+                                 finite_values_only, nan_repr);
+  };
+
+  template <typename Fn, typename Init>
+  static constexpr auto reduce_member_types(Fn f, Init init) {
+    constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
+    return dummy_type.reduce_members(f, init);
+  };
+
+  static constexpr auto id_size_bits() {
+    return reduce_member_types(
+        [](int acc, auto member) -> int {
+          return acc + member_id_field_width<decltype(member)>();
+        },
+        0);
+  }
+
+ public:
+  // unique id for this scalar type that can be computed at compile time for
+  // c++17 template specialization this is not needed once we migrate to
+  // c++20 and can pass literal classes as template parameters
+  constexpr Id id() const {
+    static_assert(id_size_bits() <= sizeof(Id) * 8,
+                  "ScalarType id is too large to be stored");
+
+    auto or_and_advance = [](std::pair<Id, uint32_t> result,
+                             auto member) -> std::pair<Id, uint32_t> {
+      auto [id, bit_offset] = result;
+      auto constexpr bits = member_id_field_width<decltype(member)>();
+      return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1))
+                       << bit_offset,
+              bit_offset + bits};
+    };
+    return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
+  }
+
+  // create a ScalarType from an id, for c++17 template specialization,
+  // this is not needed once we migrate to c++20 and can pass literal
+  // classes as template parameters
+  static constexpr ScalarType from_id(Id id) {
+    auto extract_and_advance = [id](auto result, auto member) {
+      using T = decltype(member);
+      auto [tuple, bit_offset] = result;
+      auto constexpr bits = member_id_field_width<T>();
+      auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) &
+                                          ((uint64_t(1) << bits) - 1));
+      auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
+      return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
+    };
+
+    auto [tuple_args, _] = reduce_member_types(extract_and_advance,
+                                               std::pair<std::tuple<>, int>{});
+    return std::apply([](auto... args) { return ScalarType(args...); },
+                      tuple_args);
+  }
+
+  constexpr int64_t size_bits() const {
+    return mantissa + exponent + is_signed();
+  }
+  constexpr bool is_signed() const { return signed_; }
+  constexpr bool is_integer() const { return exponent == 0; }
+  constexpr bool is_floating_point() const { return exponent > 0; }
+  constexpr bool is_ieee_754() const {
     return is_floating_point() && finite_values_only == false &&
            nan_repr == NAN_IEEE_754;
   }
-  bool has_nans() const { return is_floating_point() && nan_repr != NAN_NONE; }
-  bool has_infs() const {
+  constexpr bool has_nans() const {
+    return is_floating_point() && nan_repr != NAN_NONE;
+  }
+  constexpr bool has_infs() const {
     return is_floating_point() && finite_values_only == false;
   }
-  bool has_bias() const { return bias != 0; }
+  constexpr bool has_bias() const { return bias != 0; }

  private:
   double _floating_point_max() const {
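The `id()`/`from_id()` pair added above packs every constructor argument into a fixed-width bit field of a single `int64_t`, in constructor order, so a `ScalarType` can travel through a C++17 non-type template parameter and be rebuilt on the other side. A minimal standalone sketch of the same encode/decode idea (simplified field layout, hypothetical `MiniType`; not the vllm header itself):

#include <cstdint>

// Hypothetical stand-in for ScalarType: 8-bit exponent, 8-bit mantissa,
// 1-bit signed flag and a 32-bit bias packed low-to-high into an int64_t.
struct MiniType {
  uint8_t exponent;
  uint8_t mantissa;
  bool signed_;
  int32_t bias;

  constexpr int64_t id() const {
    int64_t packed = 0;
    int shift = 0;
    packed |= int64_t(exponent) << shift;               shift += 8;
    packed |= int64_t(mantissa) << shift;               shift += 8;
    packed |= int64_t(signed_) << shift;                shift += 1;
    packed |= (int64_t(bias) & 0xffffffff) << shift;    // 32-bit field
    return packed;
  }

  static constexpr MiniType from_id(int64_t id) {
    return MiniType{uint8_t(id & 0xff), uint8_t((id >> 8) & 0xff),
                    bool((id >> 16) & 0x1), int32_t((id >> 17) & 0xffffffff)};
  }
};

int main() {
  constexpr MiniType u4b8{0, 4, false, 8};        // a 4-bit unsigned type, bias 8
  constexpr int64_t id = u4b8.id();               // usable as a template argument
  constexpr MiniType back = MiniType::from_id(id);
  static_assert(back.mantissa == 4 && back.bias == 8 && !back.signed_);
  return 0;
}

The real class derives the field widths from the member types themselves (`member_id_field_width`) instead of hard-coding them, which is why the members were shrunk to `uint8_t`/`int32_t` in this commit: the packed id must fit in 64 bits.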
@@ -131,7 +216,7 @@ class ScalarType {
     return *reinterpret_cast<double*>(&double_raw);
   }

-  std::variant<int64_t, double> _raw_max() const {
+  constexpr std::variant<int64_t, double> _raw_max() const {
     if (is_floating_point()) {
       return {_floating_point_max()};
     } else {
@@ -141,7 +226,7 @@ class ScalarType {
     }
   }

-  std::variant<int64_t, double> _raw_min() const {
+  constexpr std::variant<int64_t, double> _raw_min() const {
     if (is_floating_point()) {
       TORCH_CHECK(is_signed(),
                   "We currently assume all floating point types are signed");
@@ -168,7 +253,7 @@ class ScalarType {
  public:
   // Max representable value for this scalar type.
   // (accounting for bias if there is one)
-  std::variant<int64_t, double> max() const {
+  constexpr std::variant<int64_t, double> max() const {
     return std::visit(
         [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
         _raw_max());
@@ -176,7 +261,7 @@ class ScalarType {

   // Min representable value for this scalar type.
   // (accounting for bias if there is one)
-  std::variant<int64_t, double> min() const {
+  constexpr std::variant<int64_t, double> min() const {
     return std::visit(
         [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
         _raw_min());
@@ -215,7 +300,7 @@ class ScalarType {
     }
   }

-  bool operator==(ScalarType const& other) const {
+  constexpr bool operator==(ScalarType const& other) const {
     return mantissa == other.mantissa && exponent == other.exponent &&
            bias == other.bias && signed_ == other.signed_ &&
            finite_values_only == other.finite_values_only &&
@@ -240,23 +325,59 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
   using Self = ScalarTypeTorch;
   using SelfPtr = c10::intrusive_ptr<Self>;

+  static void check_size_bits(int64_t size_bits, bool signed_) {
+    TORCH_CHECK(
+        size_bits <=
+            std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
+        "size_bits bit width is too large to be represented");
+  }
+
+  static void check_bias(int64_t bias) {
+    using Bias = decltype(std::declval<Self>().bias);
+    TORCH_CHECK(bias <= std::numeric_limits<Bias>::max() &&
+                    bias >= std::numeric_limits<Bias>::min(),
+                "bias too large or small to be represented");
+  }
+
+  static void check_exponent(int64_t exponent) {
+    TORCH_CHECK(
+        exponent <=
+            std::numeric_limits<decltype(std::declval<Self>().exponent)>::max(),
+        "exponent bit width is too large to be represented");
+  }
+
+  static void check_mantissa(int64_t mantissa) {
+    TORCH_CHECK(
+        mantissa <=
+            std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
+        "mantissa bit width is too large to be represented");
+  }
+
   static SelfPtr int_(int64_t size_bits, c10::optional<int64_t> bias) {
+    check_size_bits(size_bits, true);
+    check_bias(bias.value_or(0));
     return c10::make_intrusive<Self>(
         ScalarType::int_(size_bits, bias.value_or(0)));
   }

   static SelfPtr uint(int64_t size_bits, c10::optional<int64_t> bias) {
+    check_size_bits(size_bits, true);
+    check_bias(bias.value_or(0));
     return c10::make_intrusive<Self>(
         ScalarType::uint(size_bits, bias.value_or(0)));
   }

   static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) {
+    check_mantissa(mantissa);
+    check_exponent(exponent);
     return c10::make_intrusive<Self>(
         ScalarType::float_IEEE754(exponent, mantissa));
   }

   static SelfPtr float_(int64_t exponent, int64_t mantissa,
                         bool finite_values_only, int64_t nan_repr) {
+    check_mantissa(mantissa);
+    check_exponent(exponent);
     return c10::make_intrusive<Self>(ScalarType::float_(
         exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
   }
@@ -264,7 +385,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
   template <typename T>
   static void bind_readonly_property(torch::class_<Self>& cls,
                                      std::string const& name, T Base::*field) {
-    auto getter_func = [field = std::move(field)](SelfPtr const& self) {
+    auto getter_func_helper = [field = std::move(field)](SelfPtr const& self) {
       if constexpr (std::is_member_function_pointer_v<decltype(field)>) {
         return (self.get()->*field)();
       } else {
@@ -272,6 +393,18 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
       }
     };

+    auto getter_func = [field = std::move(field),
+                        getter_func_helper = std::move(getter_func_helper)](
+                           SelfPtr const& self) {
+      auto val = getter_func_helper(self);
+      // upconvert uint8_t, int32_t etc. to int64_t for python
+      if constexpr (std::is_integral_v<T>) {
+        return static_cast<int64_t>(val);
+      } else {
+        return val;
+      }
+    };
+
     cls.def_property(name, getter_func);
   }

@@ -340,6 +473,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
   }
 };

+using ScalarTypeId = int64_t;
 using ScalarTypeTorchPtr = c10::intrusive_ptr<ScalarTypeTorch>;

 // "rust style" names generally following:
@@ -379,4 +513,5 @@ static inline constexpr auto kHalf = kFE5M10;
 static inline constexpr auto kFloat16 = kHalf;
 static inline constexpr auto kBFloat16 = kFE8M7;

+static inline constexpr auto kFloat16Id = kFloat16.id();
 };  // namespace vllm
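With `ScalarTypeId` and the constexpr `id()`/`from_id()` machinery in place, code can key template specializations on the packed id, which is exactly how the gptq_marlin `dequant` specializations in the hunks below are selected. A hedged sketch of the pattern with a hypothetical `WeightTraits` table (the include path is an assumption):

#include "core/scalar_type.hpp"  // assumed include path for vllm::ScalarType

// Hypothetical traits table keyed on the packed id; unsupported ids simply
// fail to compile because the primary template is left undefined.
template <vllm::ScalarTypeId kId>
struct WeightTraits;

template <>
struct WeightTraits<vllm::kU4B8.id()> {
  static constexpr int size_bits = 4;
  static constexpr int bias = 8;
};

template <>
struct WeightTraits<vllm::kU8B128.id()> {
  static constexpr int size_bits = 8;
  static constexpr int bias = 128;
};

static_assert(WeightTraits<vllm::kU4B8.id()>::bias == 8);

Passing the id rather than the `ScalarType` object is the C++17 workaround the new comments mention; with C++20 class-type non-type template parameters the indirection would not be needed.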
@@ -43,7 +43,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                     int size_k, int block_rows) {}

 template <typename scalar_t,          // compute dtype, half or nv_float16
-          const int num_bits,         // number of bits used for weights
+          const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
           const int threads,          // number of threads in a threadblock
           const int thread_m_blocks,  // number of 16x16 blocks in the m
                                       // dimension (batchsize) of the
@@ -151,20 +151,21 @@ __device__ inline uint32_t prmt(uint32_t a) {
   return res;
 }

-// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
-// values. We mostly follow the strategy in the link below, with some small
-// changes:
+template <typename scalar_t, vllm::ScalarTypeId w_type_id>
+__device__ inline typename ScalarType<scalar_t>::FragB dequant(int q);
+
+//
+// Efficiently dequantize 4bit values packed in an int32 value into a full
+// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
+// with some small changes:
 // - FP16:
 // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
 // - BF16:
 // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
-template <typename scalar_t>
-__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {
-  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
-}
-
+//
 template <>
-__device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {
+__device__ inline typename ScalarType<half>::FragB
+dequant<half, vllm::kU4B8.id()>(int q) {
   const int LO = 0x000f000f;
   const int HI = 0x00f000f0;
   const int EX = 0x64006400;
@@ -187,7 +188,7 @@ __device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {

 template <>
 __device__ inline typename ScalarType<nv_bfloat16>::FragB
-dequant_4bit<nv_bfloat16>(int q) {
+dequant<nv_bfloat16, vllm::kU4B8.id()>(int q) {
   static constexpr uint32_t MASK = 0x000f000f;
   static constexpr uint32_t EX = 0x43004300;

@@ -210,19 +211,64 @@ dequant_4bit<nv_bfloat16>(int q) {
   return frag_b;
 }

+template <>
+__device__ inline typename ScalarType<half>::FragB
+dequant<half, vllm::kU4.id()>(int q) {
+  const int LO = 0x000f000f;
+  const int HI = 0x00f000f0;
+  const int EX = 0x64006400;
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
+
+  const int SUB = 0x64006400;
+  const int MUL = 0x2c002c00;
+  const int ADD = 0xd400d400;
+  typename ScalarType<half>::FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&SUB));
+  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&MUL),
+                      *reinterpret_cast<const half2*>(&ADD));
+  return frag_b;
+}
+
+template <>
+__device__ inline typename ScalarType<nv_bfloat16>::FragB
+dequant<nv_bfloat16, vllm::kU4.id()>(int q) {
+  static constexpr uint32_t MASK = 0x000f000f;
+  static constexpr uint32_t EX = 0x43004300;
+
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+  q >>= 4;
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+
+  typename ScalarType<nv_bfloat16>::FragB frag_b;
+  static constexpr uint32_t MUL = 0x3F803F80;
+  static constexpr uint32_t ADD = 0xC300C300;
+
+  frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
+                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
+                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
+  frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
+                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
+                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
+  return frag_b;
+}
+
+//
 // Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
 // bf16 Reference:
 // - FP16:
 // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
 // - BF16:
 // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
-template <typename scalar_t>
-__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
-  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
-}
-
+//
 template <>
-__device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
+__device__ inline typename ScalarType<half>::FragB
+dequant<half, vllm::kU8B128.id()>(int q) {
   static constexpr uint32_t mask_for_elt_01 = 0x5250;
   static constexpr uint32_t mask_for_elt_23 = 0x5351;
   static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
@@ -242,7 +288,7 @@ __device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {

 template <>
 __device__ inline typename ScalarType<nv_bfloat16>::FragB
-dequant_8bit<nv_bfloat16>(int q) {
+dequant<nv_bfloat16, vllm::kU8B128.id()>(int q) {
   typename ScalarType<nv_bfloat16>::FragB frag_b;

   float fp32_intermediates[4];
@@ -269,68 +315,9 @@ dequant_8bit<nv_bfloat16>(int q) {
   return frag_b;
 }

-// Zero-point dequantizers
-
-template <typename scalar_t>
-__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit_zp(int q) {
-  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
-}
-
 template <>
-__device__ inline typename ScalarType<half>::FragB dequant_4bit_zp<half>(
-    int q) {
-  const int LO = 0x000f000f;
-  const int HI = 0x00f000f0;
-  const int EX = 0x64006400;
-  // Guarantee that the `(a & b) | c` operations are LOP3s.
-  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
-  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
-
-  const int SUB = 0x64006400;
-  const int MUL = 0x2c002c00;
-  const int ADD = 0xd400d400;
-  typename ScalarType<half>::FragB frag_b;
-  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
-                      *reinterpret_cast<const half2*>(&SUB));
-  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
-                      *reinterpret_cast<const half2*>(&MUL),
-                      *reinterpret_cast<const half2*>(&ADD));
-  return frag_b;
-}
-
-template <>
-__device__ inline typename ScalarType<nv_bfloat16>::FragB
-dequant_4bit_zp<nv_bfloat16>(int q) {
-  static constexpr uint32_t MASK = 0x000f000f;
-  static constexpr uint32_t EX = 0x43004300;
-
-  // Guarantee that the `(a & b) | c` operations are LOP3s.
-
-  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
-  q >>= 4;
-  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
-
-  typename ScalarType<nv_bfloat16>::FragB frag_b;
-  static constexpr uint32_t MUL = 0x3F803F80;
-  static constexpr uint32_t ADD = 0xC300C300;
-
-  frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
-                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
-                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
-  frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
-                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
-                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
-  return frag_b;
-}
-
-template <typename scalar_t>
-__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit_zp(int q) {
-  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
-}
-
-template <>
-__device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(
-    int q) {
+__device__ inline typename ScalarType<half>::FragB
+dequant<half, vllm::kU8.id()>(int q) {
   static constexpr uint32_t mask_for_elt_01 = 0x5250;
   static constexpr uint32_t mask_for_elt_23 = 0x5351;
   static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
@@ -350,7 +337,7 @@ __device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(

 template <>
 __device__ inline typename ScalarType<nv_bfloat16>::FragB
-dequant_8bit_zp<nv_bfloat16>(int q) {
+dequant<nv_bfloat16, vllm::kU8.id()>(int q) {
   typename ScalarType<nv_bfloat16>::FragB frag_b;

   float fp32_intermediates[4];
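After these hunks, the old `dequant_4bit` / `dequant_8bit` / `*_zp` entry points collapse into a single `dequant<scalar_t, w_type_id>` template: the former zero-point variants are simply the specializations for the unbiased types `vllm::kU4` and `vllm::kU8`, and the biased GPTQ paths are the `vllm::kU4B8` / `vllm::kU8B128` specializations. A hedged device-side usage sketch (the helper name is hypothetical; it assumes the surrounding gptq_marlin headers and declarations from this file):

// Expands a packed 32-bit word by picking the specialization that matches the
// weight-type id; the selection happens entirely at compile time.
__device__ void expand_example(int q) {
  using FragB = typename ScalarType<half>::FragB;
  FragB lo = dequant<half, vllm::kU4B8.id()>(q);       // GPTQ 4-bit, bias 8
  FragB hi = dequant<half, vllm::kU4B8.id()>(q >> 8);  // next group of nibbles
  // AWQ-style call sites would use vllm::kU4.id() or vllm::kU8.id() here,
  // which correspond to the removed dequant_4bit_zp / dequant_8bit_zp paths.
  (void)lo;
  (void)hi;
}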
@@ -518,7 +505,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
 }

 template <typename scalar_t,          // compute dtype, half or nv_float16
-          const int num_bits,         // number of bits used for weights
+          const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
           const int threads,          // number of threads in a threadblock
           const int thread_m_blocks,  // number of 16x16 blocks in the m
                                       // dimension (batchsize) of the
@@ -568,7 +555,9 @@ __global__ void Marlin(
   using FragS = typename ScalarType<scalar_t>::FragS;
   using FragZP = typename ScalarType<scalar_t>::FragZP;

-  constexpr int pack_factor = 32 / num_bits;
+  static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
+
+  constexpr int pack_factor = 32 / w_type.size_bits();

   // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
   // better partitioning with less reductions
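The kernel no longer receives `num_bits` directly: it rebuilds the full weight type from the id with `from_id` and derives everything else (pack factor, 4-bit vs 8-bit paths) from it at compile time. A small hedged sketch of the same derivation outside the kernel (the include path and helper name are assumptions):

#include "core/scalar_type.hpp"  // assumed include path for vllm::ScalarType

// Mirrors the kernel-side derivation: given a weight-type id, compute how many
// quantized values fit in one 32-bit word.
template <vllm::ScalarTypeId w_type_id>
constexpr int pack_factor() {
  constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
  static_assert(w_type.size_bits() == 4 || w_type.size_bits() == 8,
                "gptq_marlin only handles 4-bit and 8-bit weights");
  return 32 / w_type.size_bits();
}

static_assert(pack_factor<vllm::kU4B8.id()>() == 8);
static_assert(pack_factor<vllm::kU8B128.id()>() == 4);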
@@ -670,7 +659,7 @@ __global__ void Marlin(
   // B sizes/strides
   int b_gl_stride = 16 * prob_n / (pack_factor * 4);
   constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
-  constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
+  constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2;
   constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;

   int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
@@ -1186,19 +1175,20 @@ __global__ void Marlin(
     if constexpr (has_zp) {
       FragB frag_zp_0;
       FragB frag_zp_1;
-      if constexpr (num_bits == 4) {
-        int zp_quant = frag_qzp[k % 2][0];
-        int zp_quant_shift = zp_quant >> 8;
-        frag_zp_0 = dequant_4bit_zp<scalar_t>(zp_quant);
-        frag_zp_1 = dequant_4bit_zp<scalar_t>(zp_quant_shift);
+      int zp_quant_0, zp_quant_1;
+
+      if constexpr (w_type.size_bits() == 4) {
+        zp_quant_0 = frag_qzp[k % 2][0];
+        zp_quant_1 = zp_quant_0 >> 8;
       } else {
-        int zp_quant_0 = frag_qzp[k % 2][0];
-        int zp_quant_1 = frag_qzp[k % 2][1];
-        frag_zp_0 = dequant_8bit_zp<scalar_t>(zp_quant_0);
-        frag_zp_1 = dequant_8bit_zp<scalar_t>(zp_quant_1);
+        static_assert(w_type.size_bits() == 8);
+        zp_quant_0 = frag_qzp[k % 2][0];
+        zp_quant_1 = frag_qzp[k % 2][1];
       }

+      frag_zp_0 = dequant<scalar_t, w_type_id>(zp_quant_0);
+      frag_zp_1 = dequant<scalar_t, w_type_id>(zp_quant_1);
+
       frag_zp[0] = frag_zp_0[0];
       frag_zp[1] = frag_zp_0[1];
       frag_zp[2] = frag_zp_1[0];
@@ -1211,32 +1201,20 @@ __global__ void Marlin(
     for (int j = 0; j < 4; j++) {
       FragB frag_b0;
       FragB frag_b1;
-      if constexpr (num_bits == 4) {
-        int b_quant = frag_b_quant[k % 2][0][j];
-        int b_quant_shift = b_quant >> 8;
-
-        if constexpr (has_zp) {
-          frag_b0 = dequant_4bit_zp<scalar_t>(b_quant);
-          frag_b1 = dequant_4bit_zp<scalar_t>(b_quant_shift);
-
-        } else {
-          frag_b0 = dequant_4bit<scalar_t>(b_quant);
-          frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
-        }
-
+      int b_quant_0, b_quant_1;
+
+      if constexpr (w_type.size_bits() == 4) {
+        b_quant_0 = frag_b_quant[k % 2][0][j];
+        b_quant_1 = b_quant_0 >> 8;
       } else {
+        static_assert(w_type.size_bits() == 8);
         int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
-        int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
-        int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
+        b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
+        b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
+      }

-        if constexpr (has_zp) {
-          frag_b0 = dequant_8bit_zp<scalar_t>(b_quant_0);
-          frag_b1 = dequant_8bit_zp<scalar_t>(b_quant_1);
-        } else {
-          frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
-          frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
-        }
-      }
+      frag_b0 = dequant<scalar_t, w_type_id>(b_quant_0);
+      frag_b1 = dequant<scalar_t, w_type_id>(b_quant_1);

       // Apply zero-point to frag_b0
       if constexpr (has_zp) {
@@ -1477,7 +1455,8 @@ __global__ void Marlin(

     // For per-column quantization we finally apply the scale here (only for
     // 4-bit)
-    if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) {
+    if constexpr (!has_act_order && group_blocks == -1 &&
+                  w_type.size_bits() == 4) {
       res = __hmul2(res, s[0]);
     }

@@ -1605,7 +1584,7 @@ __global__ void Marlin(
       // For per-column scales, we only fetch them here in the final step before
       // write-out
       if constexpr (!has_act_order && group_blocks == -1) {
-        if constexpr (num_bits == 8) {
+        if constexpr (w_type.size_bits() == 8) {
           if (s_sh_wr_pred) {
             cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
           }
@@ -1622,7 +1601,7 @@ __global__ void Marlin(

     thread_block_reduce();
     if constexpr (!has_act_order && group_blocks == -1) {
-      if constexpr (num_bits == 8) {
+      if constexpr (w_type.size_bits() == 8) {
         cp_async_wait<0>();
         __syncthreads();
         if (threadIdx.x / 32 < thread_n_blocks / 4) {
@@ -1645,7 +1624,8 @@ __global__ void Marlin(
     // For 8-bit channelwise, we apply the scale before the global reduction
     // that converts the fp32 results to fp16 (so that we avoid possible
     // overflow in fp16)
-    if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) {
+    if constexpr (!has_act_order && group_blocks == -1 &&
+                  w_type.size_bits() == 8) {
       if (threadIdx.x / 32 < thread_n_blocks / 4) {
 #pragma unroll
         for (int i = 0; i < thread_m_blocks; i++) {
@@ -1714,20 +1694,19 @@ __global__ void Marlin(
   }
 }

-#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
-                  THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
-                  NUM_THREADS) \
-  else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
+#define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
+                  HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \
+  else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
            thread_n_blocks == THREAD_N_BLOCKS && \
           thread_k_blocks == THREAD_K_BLOCKS && \
           has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
           group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
     cudaFuncSetAttribute( \
-        Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
+        Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
               THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
               HAS_ZP, GROUP_BLOCKS>, \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
-    Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
+    Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
           THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
           HAS_ZP, GROUP_BLOCKS> \
        <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
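The dispatch macro now compares the runtime weight type (`q_type`, a `vllm::ScalarType`) against each candidate constant using the new constexpr `operator==`, and instantiates `Marlin` on that constant's compile-time id. A hedged host-side sketch of the same shape, with a hypothetical `run_marlin_stub` standing in for the real launch (include path assumed):

#include <stdexcept>
#include "core/scalar_type.hpp"  // assumed include path for vllm::ScalarType

// Stand-in for a Marlin launch wrapper templated on the weight-type id.
template <vllm::ScalarTypeId w_type_id>
void run_marlin_stub() {}

// Compare the runtime ScalarType against each supported constant, then call
// the wrapper with that constant's compile-time id, as the __CALL_IF chain does.
inline void dispatch(vllm::ScalarType const& q_type) {
  if (q_type == vllm::kU4B8) {
    run_marlin_stub<vllm::kU4B8.id()>();
  } else if (q_type == vllm::kU8B128) {
    run_marlin_stub<vllm::kU8B128.id()>();
  } else if (q_type == vllm::kU4) {
    run_marlin_stub<vllm::kU4.id()>();
  } else if (q_type == vllm::kU8) {
    run_marlin_stub<vllm::kU8.id()>();
  } else {
    throw std::runtime_error("unsupported weight type");
  }
}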
@@ -1923,52 +1902,52 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
   return exec_config_t{0, {-1, -1, -1}};
 }

-#define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
+#define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
   \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
   \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
   \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
   \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
+  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)

-#define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
+#define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
   \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
   \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
   \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
+  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)

 template <typename scalar_t>
 void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
@@ -2113,23 +2092,23 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,

   if (false) {
   }
-  GPTQ_CALL_IF(4, 16, 4, 256)
-  GPTQ_CALL_IF(4, 8, 8, 256)
-  GPTQ_CALL_IF(4, 8, 4, 128)
-  GPTQ_CALL_IF(4, 4, 8, 128)
-  GPTQ_CALL_IF(8, 16, 4, 256)
-  GPTQ_CALL_IF(8, 8, 8, 256)
-  GPTQ_CALL_IF(8, 8, 4, 128)
-  GPTQ_CALL_IF(8, 4, 8, 128)
+  GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256)
+  GPTQ_CALL_IF(vllm::kU4B8, 8, 8, 256)
+  GPTQ_CALL_IF(vllm::kU4B8, 8, 4, 128)
+  GPTQ_CALL_IF(vllm::kU4B8, 4, 8, 128)
+  GPTQ_CALL_IF(vllm::kU8B128, 16, 4, 256)
+  GPTQ_CALL_IF(vllm::kU8B128, 8, 8, 256)
+  GPTQ_CALL_IF(vllm::kU8B128, 8, 4, 128)
+  GPTQ_CALL_IF(vllm::kU8B128, 4, 8, 128)

-  AWQ_CALL_IF(4, 16, 4, 256)
-  AWQ_CALL_IF(4, 8, 8, 256)
-  AWQ_CALL_IF(4, 8, 4, 128)
-  AWQ_CALL_IF(4, 4, 8, 128)
-  AWQ_CALL_IF(8, 16, 4, 256)
-  AWQ_CALL_IF(8, 8, 8, 256)
-  AWQ_CALL_IF(8, 8, 4, 128)
-  AWQ_CALL_IF(8, 4, 8, 128)
+  AWQ_CALL_IF(vllm::kU4, 16, 4, 256)
+  AWQ_CALL_IF(vllm::kU4, 8, 8, 256)
+  AWQ_CALL_IF(vllm::kU4, 8, 4, 128)
+  AWQ_CALL_IF(vllm::kU4, 4, 8, 128)
+  AWQ_CALL_IF(vllm::kU8, 16, 4, 256)
+  AWQ_CALL_IF(vllm::kU8, 8, 8, 256)
+  AWQ_CALL_IF(vllm::kU8, 8, 4, 128)
+  AWQ_CALL_IF(vllm::kU8, 4, 8, 128)
   else {
     TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
                 ", ", prob_k, "]", ", has_act_order = ", has_act_order,
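The call sites now name the weight format explicitly: GPTQ-style checkpoints dispatch on the biased types kU4B8/kU8B128, while AWQ-style (runtime zero-point) checkpoints dispatch on the unbiased kU4/kU8. A hedged sketch of a helper that mirrors this mapping (the function itself is hypothetical, not part of the commit):

#include "core/scalar_type.hpp"  // assumed include path for vllm::ScalarType

// GPTQ uses symmetric biased types; AWQ uses plain unsigned types whose zero
// points are applied at runtime inside the kernel.
inline vllm::ScalarType weight_type_for(bool is_awq, int num_bits) {
  if (is_awq) {
    return num_bits == 4 ? vllm::kU4 : vllm::kU8;
  }
  return num_bits == 4 ? vllm::kU4B8 : vllm::kU8B128;
}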