[Misc] Use scalar type to dispatch to different gptq_marlin kernels (#7323)

Lucas Wilkinson 2024-08-12 14:40:13 -04:00 committed by GitHub
parent 1137f343aa
commit 6aa33cb2dd
2 changed files with 332 additions and 218 deletions


@@ -20,7 +20,7 @@ namespace vllm {
 //
 class ScalarType {
  public:
-  enum NanRepr : int64_t {
+  enum NanRepr : uint8_t {
     NAN_NONE = 0,                // nans are not supported
     NAN_IEEE_754 = 1,            // nans are: exp all 1s, mantissa not all 0s
     NAN_EXTD_RANGE_MAX_MIN = 2,  // nans are: exp all 1s, mantissa all 1s
@@ -28,33 +28,33 @@ class ScalarType {
     NAN_REPR_ID_MAX
   };

-  constexpr ScalarType(bool signed_, int64_t exponent, int64_t mantissa,
-                       int64_t bias, bool finite_values_only = false,
+  constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
+                       int32_t bias, bool finite_values_only = false,
                        NanRepr nan_repr = NAN_IEEE_754)
       : exponent(exponent),
         mantissa(mantissa),
-        bias(bias),
         signed_(signed_),
+        bias(bias),
         finite_values_only(finite_values_only),
         nan_repr(nan_repr){};

-  static constexpr ScalarType int_(int64_t size_bits, int64_t bias = 0) {
-    return ScalarType(true, 0, size_bits - 1, bias);
+  static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
+    return ScalarType(0, size_bits - 1, true, bias);
   }

-  static constexpr ScalarType uint(int64_t size_bits, int64_t bias = 0) {
-    return ScalarType(false, 0, size_bits, bias);
+  static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
+    return ScalarType(0, size_bits, false, bias);
   }

   // IEEE 754 compliant floating point type
-  static constexpr ScalarType float_IEEE754(int64_t exponent,
-                                            int64_t mantissa) {
+  static constexpr ScalarType float_IEEE754(uint8_t exponent,
+                                            uint8_t mantissa) {
     TORCH_CHECK(mantissa > 0 && exponent > 0);
-    return ScalarType(true, exponent, mantissa, 0, false, NAN_IEEE_754);
+    return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
   }

   // IEEE 754 non-compliant floating point type
-  static constexpr ScalarType float_(int64_t exponent, int64_t mantissa,
+  static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
                                      bool finite_values_only,
                                      NanRepr nan_repr) {
     TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
@@ -62,36 +62,121 @@ class ScalarType {
     TORCH_CHECK(nan_repr != NAN_IEEE_754,
                 "use `float_IEEE754` constructor for floating point types that "
                 "follow IEEE 754 conventions");
-    return ScalarType(true, exponent, mantissa, 0, finite_values_only,
+    return ScalarType(exponent, mantissa, true, 0, finite_values_only,
                       nan_repr);
   }

-  int64_t const exponent;  // size of the exponent field (0 for integer types)
-  int64_t const mantissa;  // size of the mantissa field (size of the integer
+  uint8_t const exponent;  // size of the exponent field (0 for integer types)
+  uint8_t const mantissa;  // size of the mantissa field (size of the integer
                            // excluding the sign bit for integer types)
-  int64_t const bias;      // stored values equal value + bias,
-                           // used for quantized type
   bool const signed_;  // flag if the type supports negative numbers (i.e. has a
                        // sign bit)
+  int32_t const bias;  // stored values equal value + bias,
+                       // used for quantized type

   // Extra Floating point info
   bool const finite_values_only;  // i.e. no +/-inf if true
   NanRepr const nan_repr;         // how NaNs are represented
                                   // (not applicable for integer types)

-  int64_t size_bits() const { return mantissa + exponent + is_signed(); }
-  bool is_signed() const { return signed_; }
-  bool is_integer() const { return exponent == 0; }
-  bool is_floating_point() const { return exponent > 0; }
-  bool is_ieee_754() const {
+  using Id = int64_t;
+
+ private:
+  // Field size in id
+  template <typename T_>
+  static constexpr size_t member_id_field_width() {
+    using T = std::decay_t<T_>;
+    return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
+  }
+
+  template <typename Fn, typename Init, typename Member, typename... Rest>
+  static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
+                                              Rest... rest) {
+    auto new_val = f(val, member);
+    if constexpr (sizeof...(rest) > 0) {
+      return reduce_members_helper(f, new_val, rest...);
+    } else {
+      return new_val;
+    };
+  }
+
+  template <typename Fn, typename Init>
+  constexpr auto reduce_members(Fn f, Init init) const {
+    // Should be in constructor order for `from_id`
+    return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
+                                 finite_values_only, nan_repr);
+  };
+
+  template <typename Fn, typename Init>
+  static constexpr auto reduce_member_types(Fn f, Init init) {
+    constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
+    return dummy_type.reduce_members(f, init);
+  };
+
+  static constexpr auto id_size_bits() {
+    return reduce_member_types(
+        [](int acc, auto member) -> int {
+          return acc + member_id_field_width<decltype(member)>();
+        },
+        0);
+  }
+
+ public:
+  // unique id for this scalar type that can be computed at compile time for
+  // c++17 template specialization this is not needed once we migrate to
+  // c++20 and can pass literal classes as template parameters
+  constexpr Id id() const {
+    static_assert(id_size_bits() <= sizeof(Id) * 8,
+                  "ScalarType id is too large to be stored");
+
+    auto or_and_advance = [](std::pair<Id, uint32_t> result,
+                             auto member) -> std::pair<Id, uint32_t> {
+      auto [id, bit_offset] = result;
+      auto constexpr bits = member_id_field_width<decltype(member)>();
+      return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1))
+                       << bit_offset,
+              bit_offset + bits};
+    };
+    return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
+  }
+
+  // create a ScalarType from an id, for c++17 template specialization,
+  // this is not needed once we migrate to c++20 and can pass literal
+  // classes as template parameters
+  static constexpr ScalarType from_id(Id id) {
+    auto extract_and_advance = [id](auto result, auto member) {
+      using T = decltype(member);
+      auto [tuple, bit_offset] = result;
+      auto constexpr bits = member_id_field_width<T>();
+      auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) &
+                                          ((uint64_t(1) << bits) - 1));
+      auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
+      return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
+    };
+
+    auto [tuple_args, _] = reduce_member_types(extract_and_advance,
+                                               std::pair<std::tuple<>, int>{});
+    return std::apply([](auto... args) { return ScalarType(args...); },
+                      tuple_args);
+  }
+
+  constexpr int64_t size_bits() const {
+    return mantissa + exponent + is_signed();
+  }
+  constexpr bool is_signed() const { return signed_; }
+  constexpr bool is_integer() const { return exponent == 0; }
+  constexpr bool is_floating_point() const { return exponent > 0; }
+  constexpr bool is_ieee_754() const {
     return is_floating_point() && finite_values_only == false &&
            nan_repr == NAN_IEEE_754;
   }
-  bool has_nans() const { return is_floating_point() && nan_repr != NAN_NONE; }
-  bool has_infs() const {
+  constexpr bool has_nans() const {
+    return is_floating_point() && nan_repr != NAN_NONE;
+  }
+  constexpr bool has_infs() const {
     return is_floating_point() && finite_values_only == false;
   }
-  bool has_bias() const { return bias != 0; }
+  constexpr bool has_bias() const { return bias != 0; }

  private:
   double _floating_point_max() const {
@@ -131,7 +216,7 @@ class ScalarType {
     return *reinterpret_cast<double*>(&double_raw);
   }

-  std::variant<int64_t, double> _raw_max() const {
+  constexpr std::variant<int64_t, double> _raw_max() const {
     if (is_floating_point()) {
       return {_floating_point_max()};
     } else {
@@ -141,7 +226,7 @@ class ScalarType {
     }
   }

-  std::variant<int64_t, double> _raw_min() const {
+  constexpr std::variant<int64_t, double> _raw_min() const {
     if (is_floating_point()) {
       TORCH_CHECK(is_signed(),
                   "We currently assume all floating point types are signed");
@@ -168,7 +253,7 @@ class ScalarType {
  public:
   // Max representable value for this scalar type.
   // (accounting for bias if there is one)
-  std::variant<int64_t, double> max() const {
+  constexpr std::variant<int64_t, double> max() const {
     return std::visit(
         [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
         _raw_max());
@@ -176,7 +261,7 @@ class ScalarType {
   // Min representable value for this scalar type.
   // (accounting for bias if there is one)
-  std::variant<int64_t, double> min() const {
+  constexpr std::variant<int64_t, double> min() const {
     return std::visit(
         [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
         _raw_min());
@@ -215,7 +300,7 @@ class ScalarType {
     }
   }

-  bool operator==(ScalarType const& other) const {
+  constexpr bool operator==(ScalarType const& other) const {
     return mantissa == other.mantissa && exponent == other.exponent &&
            bias == other.bias && signed_ == other.signed_ &&
            finite_values_only == other.finite_values_only &&
@@ -240,23 +325,59 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
   using Self = ScalarTypeTorch;
   using SelfPtr = c10::intrusive_ptr<Self>;

+  static void check_size_bits(int64_t size_bits, bool signed_) {
+    TORCH_CHECK(
+        size_bits <=
+            std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
+        "size_bits bit width is too large to be represented");
+  }
+
+  static void check_bias(int64_t bias) {
+    using Bias = decltype(std::declval<Self>().bias);
+    TORCH_CHECK(bias <= std::numeric_limits<Bias>::max() &&
+                    bias >= std::numeric_limits<Bias>::min(),
+                "bias too large or small to be represented");
+  }
+
+  static void check_exponent(int64_t exponent) {
+    TORCH_CHECK(
+        exponent <=
+            std::numeric_limits<decltype(std::declval<Self>().exponent)>::max(),
+        "exponent bit width is too large to be represented");
+  }
+
+  static void check_mantissa(int64_t mantissa) {
+    TORCH_CHECK(
+        mantissa <=
+            std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
+        "mantissa bit width is too large to be represented");
+  }
+
   static SelfPtr int_(int64_t size_bits, c10::optional<int64_t> bias) {
+    check_size_bits(size_bits, true);
+    check_bias(bias.value_or(0));
     return c10::make_intrusive<Self>(
         ScalarType::int_(size_bits, bias.value_or(0)));
   }

   static SelfPtr uint(int64_t size_bits, c10::optional<int64_t> bias) {
+    check_size_bits(size_bits, true);
+    check_bias(bias.value_or(0));
     return c10::make_intrusive<Self>(
         ScalarType::uint(size_bits, bias.value_or(0)));
   }

   static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) {
+    check_mantissa(mantissa);
+    check_exponent(exponent);
     return c10::make_intrusive<Self>(
         ScalarType::float_IEEE754(exponent, mantissa));
   }

   static SelfPtr float_(int64_t exponent, int64_t mantissa,
                         bool finite_values_only, int64_t nan_repr) {
+    check_mantissa(mantissa);
+    check_exponent(exponent);
     return c10::make_intrusive<Self>(ScalarType::float_(
         exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
   }
@@ -264,7 +385,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
   template <typename T>
   static void bind_readonly_property(torch::class_<Self>& cls,
                                      std::string const& name, T Base::*field) {
-    auto getter_func = [field = std::move(field)](SelfPtr const& self) {
+    auto getter_func_helper = [field = std::move(field)](SelfPtr const& self) {
       if constexpr (std::is_member_function_pointer_v<decltype(field)>) {
         return (self.get()->*field)();
       } else {
@@ -272,6 +393,18 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
       }
     };

+    auto getter_func = [field = std::move(field),
+                        getter_func_helper = std::move(getter_func_helper)](
+                           SelfPtr const& self) {
+      auto val = getter_func_helper(self);
+
+      // upconvert uint8_t, int32_t etc. to int64_t for python
+      if constexpr (std::is_integral_v<T>) {
+        return static_cast<int64_t>(val);
+      } else {
+        return val;
+      }
+    };
+
     cls.def_property(name, getter_func);
   }
@@ -340,6 +473,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
   }
 };

+using ScalarTypeId = int64_t;
 using ScalarTypeTorchPtr = c10::intrusive_ptr<ScalarTypeTorch>;

 // "rust style" names generally following:
@@ -379,4 +513,5 @@ static inline constexpr auto kHalf = kFE5M10;
 static inline constexpr auto kFloat16 = kHalf;
 static inline constexpr auto kBFloat16 = kFE8M7;

+static inline constexpr auto kFloat16Id = kFloat16.id();
 };  // namespace vllm
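
The compile-time `id()` / `from_id()` pair added above is what lets the Marlin kernel in the second file take the weight type as a single integer template parameter. The standalone sketch below shows the same pattern; it is not part of the commit and does not use the vLLM headers, and `TypeDesc`, `fake_kernel`, and the fixed field widths are illustrative assumptions only.

// Pack a scalar-type description into one int64_t at compile time so it can
// be passed as a C++17 non-type template parameter, then rebuild it inside
// the templated function (the same trick ScalarType::id()/from_id() perform).
#include <cstdint>
#include <iostream>

struct TypeDesc {
  uint8_t exponent;
  uint8_t mantissa;
  bool signed_;

  constexpr int64_t id() const {
    // Fixed 8/8/1-bit layout for brevity; ScalarType derives widths per member.
    return int64_t(exponent) | (int64_t(mantissa) << 8) |
           (int64_t(signed_) << 16);
  }
  static constexpr TypeDesc from_id(int64_t id) {
    return TypeDesc{uint8_t(id & 0xff), uint8_t((id >> 8) & 0xff),
                    bool((id >> 16) & 0x1)};
  }
};

inline constexpr TypeDesc kU4Demo{0, 4, false};  // 4-bit unsigned integer type
inline constexpr TypeDesc kU8Demo{0, 8, false};  // 8-bit unsigned integer type

// A stand-in for a kernel templated on the packed id: the description is
// recovered with from_id and queried through constexpr members.
template <int64_t type_id>
void fake_kernel() {
  constexpr TypeDesc t = TypeDesc::from_id(type_id);
  constexpr int pack_factor = 32 / (t.mantissa + t.exponent + t.signed_);
  std::cout << "pack_factor = " << pack_factor << '\n';
}

int main() {
  fake_kernel<kU4Demo.id()>();  // prints pack_factor = 8
  fake_kernel<kU8Demo.id()>();  // prints pack_factor = 4
}

Because every step is constexpr, quantities derived from the description (such as the pack factor) stay compile-time constants inside the template, which is what the `from_id(w_type_id)` call in the Marlin kernel relies on.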


@@ -43,7 +43,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                     int size_k, int block_rows) {}

 template <typename scalar_t,          // compute dtype, half or nv_float16
-          const int num_bits,         // number of bits used for weights
+          const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
           const int threads,          // number of threads in a threadblock
           const int thread_m_blocks,  // number of 16x16 blocks in the m
                                       // dimension (batchsize) of the
@@ -151,20 +151,21 @@ __device__ inline uint32_t prmt(uint32_t a) {
   return res;
 }

-// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
-// values. We mostly follow the strategy in the link below, with some small
-// changes:
+template <typename scalar_t, vllm::ScalarTypeId w_type_id>
+__device__ inline typename ScalarType<scalar_t>::FragB dequant(int q);
+
+//
+// Efficiently dequantize 4bit values packed in an int32 value into a full
+// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
+// with some small changes:
 // - FP16:
 // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
 // - BF16:
 // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
-template <typename scalar_t>
-__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {
-  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
-}
-
+//
 template <>
-__device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {
+__device__ inline typename ScalarType<half>::FragB
+dequant<half, vllm::kU4B8.id()>(int q) {
   const int LO = 0x000f000f;
   const int HI = 0x00f000f0;
   const int EX = 0x64006400;
@@ -187,7 +188,7 @@ __device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {

 template <>
 __device__ inline typename ScalarType<nv_bfloat16>::FragB
-dequant_4bit<nv_bfloat16>(int q) {
+dequant<nv_bfloat16, vllm::kU4B8.id()>(int q) {
   static constexpr uint32_t MASK = 0x000f000f;
   static constexpr uint32_t EX = 0x43004300;
@@ -210,19 +211,64 @@ dequant_4bit<nv_bfloat16>(int q) {
   return frag_b;
 }

+template <>
+__device__ inline typename ScalarType<half>::FragB
+dequant<half, vllm::kU4.id()>(int q) {
+  const int LO = 0x000f000f;
+  const int HI = 0x00f000f0;
+  const int EX = 0x64006400;
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
+
+  const int SUB = 0x64006400;
+  const int MUL = 0x2c002c00;
+  const int ADD = 0xd400d400;
+  typename ScalarType<half>::FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&SUB));
+  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&MUL),
+                      *reinterpret_cast<const half2*>(&ADD));
+  return frag_b;
+}
+
+template <>
+__device__ inline typename ScalarType<nv_bfloat16>::FragB
+dequant<nv_bfloat16, vllm::kU4.id()>(int q) {
+  static constexpr uint32_t MASK = 0x000f000f;
+  static constexpr uint32_t EX = 0x43004300;
+
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+  q >>= 4;
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+
+  typename ScalarType<nv_bfloat16>::FragB frag_b;
+  static constexpr uint32_t MUL = 0x3F803F80;
+  static constexpr uint32_t ADD = 0xC300C300;
+  frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
+                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
+                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
+  frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
+                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
+                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
+  return frag_b;
+}
+
+//
 // Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
 // bf16 Reference:
 // - FP16:
 // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
 // - BF16:
 // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
-template <typename scalar_t>
-__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
-  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
-}
-
+//
 template <>
-__device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
+__device__ inline typename ScalarType<half>::FragB
+dequant<half, vllm::kU8B128.id()>(int q) {
   static constexpr uint32_t mask_for_elt_01 = 0x5250;
   static constexpr uint32_t mask_for_elt_23 = 0x5351;
   static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
@@ -242,7 +288,7 @@ __device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {

 template <>
 __device__ inline typename ScalarType<nv_bfloat16>::FragB
-dequant_8bit<nv_bfloat16>(int q) {
+dequant<nv_bfloat16, vllm::kU8B128.id()>(int q) {
   typename ScalarType<nv_bfloat16>::FragB frag_b;

   float fp32_intermediates[4];
@@ -269,68 +315,9 @@ dequant_8bit<nv_bfloat16>(int q) {
   return frag_b;
 }

-// Zero-point dequantizers
-template <typename scalar_t>
-__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit_zp(int q) {
-  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
-}
-
 template <>
-__device__ inline typename ScalarType<half>::FragB dequant_4bit_zp<half>(
-    int q) {
-  const int LO = 0x000f000f;
-  const int HI = 0x00f000f0;
-  const int EX = 0x64006400;
-  // Guarantee that the `(a & b) | c` operations are LOP3s.
-  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
-  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
-
-  const int SUB = 0x64006400;
-  const int MUL = 0x2c002c00;
-  const int ADD = 0xd400d400;
-  typename ScalarType<half>::FragB frag_b;
-  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
-                      *reinterpret_cast<const half2*>(&SUB));
-  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
-                      *reinterpret_cast<const half2*>(&MUL),
-                      *reinterpret_cast<const half2*>(&ADD));
-  return frag_b;
-}
-
-template <>
-__device__ inline typename ScalarType<nv_bfloat16>::FragB
-dequant_4bit_zp<nv_bfloat16>(int q) {
-  static constexpr uint32_t MASK = 0x000f000f;
-  static constexpr uint32_t EX = 0x43004300;
-
-  // Guarantee that the `(a & b) | c` operations are LOP3s.
-  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
-  q >>= 4;
-  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
-
-  typename ScalarType<nv_bfloat16>::FragB frag_b;
-  static constexpr uint32_t MUL = 0x3F803F80;
-  static constexpr uint32_t ADD = 0xC300C300;
-  frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
-                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
-                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
-  frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
-                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
-                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
-  return frag_b;
-}
-
-template <typename scalar_t>
-__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit_zp(int q) {
-  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
-}
-
-template <>
-__device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(
-    int q) {
+__device__ inline typename ScalarType<half>::FragB
+dequant<half, vllm::kU8.id()>(int q) {
   static constexpr uint32_t mask_for_elt_01 = 0x5250;
   static constexpr uint32_t mask_for_elt_23 = 0x5351;
   static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
@@ -350,7 +337,7 @@ __device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(

 template <>
 __device__ inline typename ScalarType<nv_bfloat16>::FragB
-dequant_8bit_zp<nv_bfloat16>(int q) {
+dequant<nv_bfloat16, vllm::kU8.id()>(int q) {
   typename ScalarType<nv_bfloat16>::FragB frag_b;

   float fp32_intermediates[4];
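
With the weight type carried as a template parameter, the separate `dequant_4bit` / `dequant_4bit_zp` / `dequant_8bit` / `dequant_8bit_zp` helpers collapse into one `dequant<scalar_t, w_type_id>` primary template plus one explicit specialization per supported combination, as the hunks above show. Below is a minimal host-side sketch of that shape; `TypeDesc`, `kU4B8Demo`, `kU4Demo`, and `Frag` are assumed stand-ins, and the real specializations operate on packed `half2` / `nv_bfloat162` fragments rather than floats.

// One primary template keyed on (compute type, weight-type id) with explicit
// specializations; the primary template is declared but never defined.
#include <array>
#include <cstdint>
#include <iostream>

struct TypeDesc {
  uint8_t size_bits;
  int32_t bias;  // stored value = real value + bias
  constexpr int64_t id() const {
    return int64_t(size_bits) | (int64_t(uint32_t(bias)) << 8);
  }
};

inline constexpr TypeDesc kU4B8Demo{4, 8};  // 4-bit unsigned with a bias of 8
inline constexpr TypeDesc kU4Demo{4, 0};    // 4-bit unsigned, zero point applied later

using Frag = std::array<float, 4>;

template <typename scalar_t, int64_t w_type_id>
Frag dequant(int q);  // unsupported combinations fail at link time

template <>
Frag dequant<float, kU4B8Demo.id()>(int q) {
  // Biased 4-bit: subtract the implicit bias of 8 while unpacking.
  return {float((q >> 0 & 0xf) - 8), float((q >> 4 & 0xf) - 8),
          float((q >> 8 & 0xf) - 8), float((q >> 12 & 0xf) - 8)};
}

template <>
Frag dequant<float, kU4Demo.id()>(int q) {
  // Unbiased 4-bit: keep raw values; a zero point is subtracted elsewhere.
  return {float(q >> 0 & 0xf), float(q >> 4 & 0xf),
          float(q >> 8 & 0xf), float(q >> 12 & 0xf)};
}

int main() {
  int packed = 0x4321;
  for (float v : dequant<float, kU4B8Demo.id()>(packed)) std::cout << v << ' ';
  std::cout << '\n';
  for (float v : dequant<float, kU4Demo.id()>(packed)) std::cout << v << ' ';
  std::cout << '\n';
}

Leaving the primary template undefined rejects unsupported (compute type, weight type) pairs, which serves the same purpose as the old `STATIC_ASSERT_SCALAR_TYPE_VALID` stubs, only at link time rather than compile time.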
@@ -518,7 +505,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
 }

 template <typename scalar_t,          // compute dtype, half or nv_float16
-          const int num_bits,         // number of bits used for weights
+          const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
           const int threads,          // number of threads in a threadblock
           const int thread_m_blocks,  // number of 16x16 blocks in the m
                                       // dimension (batchsize) of the
@@ -568,7 +555,9 @@ __global__ void Marlin(
   using FragS = typename ScalarType<scalar_t>::FragS;
   using FragZP = typename ScalarType<scalar_t>::FragZP;

-  constexpr int pack_factor = 32 / num_bits;
+  static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
+  constexpr int pack_factor = 32 / w_type.size_bits();

   // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
   // better partitioning with less reductions
@@ -670,7 +659,7 @@ __global__ void Marlin(
   // B sizes/strides
   int b_gl_stride = 16 * prob_n / (pack_factor * 4);
   constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
-  constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
+  constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2;
   constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
   int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
@@ -1186,19 +1175,20 @@ __global__ void Marlin(
     if constexpr (has_zp) {
       FragB frag_zp_0;
       FragB frag_zp_1;
-      if constexpr (num_bits == 4) {
-        int zp_quant = frag_qzp[k % 2][0];
-        int zp_quant_shift = zp_quant >> 8;
-        frag_zp_0 = dequant_4bit_zp<scalar_t>(zp_quant);
-        frag_zp_1 = dequant_4bit_zp<scalar_t>(zp_quant_shift);
+      int zp_quant_0, zp_quant_1;
+      if constexpr (w_type.size_bits() == 4) {
+        zp_quant_0 = frag_qzp[k % 2][0];
+        zp_quant_1 = zp_quant_0 >> 8;
       } else {
-        int zp_quant_0 = frag_qzp[k % 2][0];
-        int zp_quant_1 = frag_qzp[k % 2][1];
-        frag_zp_0 = dequant_8bit_zp<scalar_t>(zp_quant_0);
-        frag_zp_1 = dequant_8bit_zp<scalar_t>(zp_quant_1);
+        static_assert(w_type.size_bits() == 8);
+        zp_quant_0 = frag_qzp[k % 2][0];
+        zp_quant_1 = frag_qzp[k % 2][1];
       }

+      frag_zp_0 = dequant<scalar_t, w_type_id>(zp_quant_0);
+      frag_zp_1 = dequant<scalar_t, w_type_id>(zp_quant_1);
+
       frag_zp[0] = frag_zp_0[0];
       frag_zp[1] = frag_zp_0[1];
       frag_zp[2] = frag_zp_1[0];
@@ -1211,32 +1201,20 @@ __global__ void Marlin(
     for (int j = 0; j < 4; j++) {
       FragB frag_b0;
       FragB frag_b1;
-      if constexpr (num_bits == 4) {
-        int b_quant = frag_b_quant[k % 2][0][j];
-        int b_quant_shift = b_quant >> 8;
-
-        if constexpr (has_zp) {
-          frag_b0 = dequant_4bit_zp<scalar_t>(b_quant);
-          frag_b1 = dequant_4bit_zp<scalar_t>(b_quant_shift);
-        } else {
-          frag_b0 = dequant_4bit<scalar_t>(b_quant);
-          frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
-        }
+      int b_quant_0, b_quant_1;
+
+      if constexpr (w_type.size_bits() == 4) {
+        b_quant_0 = frag_b_quant[k % 2][0][j];
+        b_quant_1 = b_quant_0 >> 8;
       } else {
+        static_assert(w_type.size_bits() == 8);
         int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
-        int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
-        int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
-
-        if constexpr (has_zp) {
-          frag_b0 = dequant_8bit_zp<scalar_t>(b_quant_0);
-          frag_b1 = dequant_8bit_zp<scalar_t>(b_quant_1);
-        } else {
-          frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
-          frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
-        }
-      }
+        b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
+        b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
+      }
+
+      frag_b0 = dequant<scalar_t, w_type_id>(b_quant_0);
+      frag_b1 = dequant<scalar_t, w_type_id>(b_quant_1);

       // Apply zero-point to frag_b0
       if constexpr (has_zp) {
@@ -1477,7 +1455,8 @@ __global__ void Marlin(
     // For per-column quantization we finally apply the scale here (only for
     // 4-bit)
-    if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) {
+    if constexpr (!has_act_order && group_blocks == -1 &&
+                  w_type.size_bits() == 4) {
       res = __hmul2(res, s[0]);
     }
@@ -1605,7 +1584,7 @@ __global__ void Marlin(
     // For per-column scales, we only fetch them here in the final step before
     // write-out
     if constexpr (!has_act_order && group_blocks == -1) {
-      if constexpr (num_bits == 8) {
+      if constexpr (w_type.size_bits() == 8) {
         if (s_sh_wr_pred) {
           cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
         }
@@ -1622,7 +1601,7 @@ __global__ void Marlin(
     thread_block_reduce();

     if constexpr (!has_act_order && group_blocks == -1) {
-      if constexpr (num_bits == 8) {
+      if constexpr (w_type.size_bits() == 8) {
         cp_async_wait<0>();
         __syncthreads();
         if (threadIdx.x / 32 < thread_n_blocks / 4) {
@@ -1645,7 +1624,8 @@ __global__ void Marlin(
     // For 8-bit channelwise, we apply the scale before the global reduction
     // that converts the fp32 results to fp16 (so that we avoid possible
     // overflow in fp16)
-    if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) {
+    if constexpr (!has_act_order && group_blocks == -1 &&
+                  w_type.size_bits() == 8) {
       if (threadIdx.x / 32 < thread_n_blocks / 4) {
   #pragma unroll
         for (int i = 0; i < thread_m_blocks; i++) {
@@ -1714,20 +1694,19 @@ __global__ void Marlin(
   }
 }

-#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,              \
-                  THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS,    \
-                  NUM_THREADS)                                             \
-  else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS &&   \
+#define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
+                  HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS)          \
+  else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS &&         \
            thread_n_blocks == THREAD_N_BLOCKS &&                             \
            thread_k_blocks == THREAD_K_BLOCKS &&                             \
            has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP &&             \
            group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) {     \
     cudaFuncSetAttribute(                                                    \
-        Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,             \
+        Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS,          \
                THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
                HAS_ZP, GROUP_BLOCKS>,                                        \
         cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \
-    Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,                 \
+    Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS,              \
            THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER,     \
            HAS_ZP, GROUP_BLOCKS>                                             \
         <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                   \
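
`__CALL_IF` keeps the existing launch-selection idiom: each expansion contributes one `else if` arm that matches the runtime configuration against one template instantiation, now keyed on `W_TYPE.id()` instead of `NUM_BITS`. A trimmed, self-contained sketch of the idiom follows; `fake_marlin`, the id constants, and the matched fields are stand-ins rather than the real `Marlin` signature.

#include <cstdint>
#include <iostream>
#include <stdexcept>

template <int64_t w_type_id, int threads>
void fake_marlin(int m, int n, int k) {
  std::cout << "launch id=" << w_type_id << " threads=" << threads << " for "
            << m << "x" << n << "x" << k << '\n';
}

constexpr int64_t kU4B8DemoId = 1;    // stand-in for a 4-bit weight-type id
constexpr int64_t kU8B128DemoId = 2;  // stand-in for an 8-bit weight-type id

// Each expansion is one `else if` arm; the chain is opened by `if (false) {}`
// in the caller and closed by a final `else` that reports an unsupported case.
#define CALL_IF(W_TYPE_ID, NUM_THREADS)                            \
  else if (w_type_id == W_TYPE_ID && num_threads == NUM_THREADS) { \
    fake_marlin<W_TYPE_ID, NUM_THREADS>(m, n, k);                  \
  }

void dispatch(int64_t w_type_id, int num_threads, int m, int n, int k) {
  if (false) {
  }
  CALL_IF(kU4B8DemoId, 128)
  CALL_IF(kU4B8DemoId, 256)
  CALL_IF(kU8B128DemoId, 128)
  CALL_IF(kU8B128DemoId, 256)
  else {
    throw std::runtime_error("unsupported weight type / thread count");
  }
}

int main() { dispatch(kU4B8DemoId, 256, 16, 4096, 4096); }

The `if (false) {}` opener lets every expansion start with `else if`, and the trailing `else` is where the caller reports an unsupported combination, mirroring the TORCH_CHECK at the end of marlin_mm in this diff.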
@@ -1923,52 +1902,52 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
   return exec_config_t{0, {-1, -1, -1}};
 }

-#define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS)             \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS)   \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS)   \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS)   \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS)   \
-                                                                            \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS)  \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS)  \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)  \
-                                                                            \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS)  \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS)  \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)  \
-                                                                            \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS)  \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS)  \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)  \
-                                                                            \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS)  \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS)  \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
+#define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)             \
+  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS)   \
+  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS)   \
+  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS)   \
+  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS)   \
+                                                                          \
+  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS)  \
+  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS)  \
+  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)  \
+                                                                          \
+  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS)  \
+  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS)  \
+  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)  \
+                                                                          \
+  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS)  \
+  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS)  \
+  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)  \
+                                                                          \
+  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS)  \
+  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS)  \
+  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)

-#define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS)             \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)  \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)  \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)  \
-                                                                           \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)  \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)  \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)  \
-                                                                           \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)  \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)  \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)  \
-                                                                           \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)  \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)  \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
+#define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)             \
+  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)  \
+  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)  \
+  __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)  \
+                                                                         \
+  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)  \
+  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)  \
+  __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)  \
+                                                                         \
+  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)  \
+  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)  \
+  __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)  \
+                                                                         \
+  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)  \
+  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)  \
+  __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)

 template <typename scalar_t>
 void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
@@ -2113,23 +2092,23 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
   if (false) {
   }
-  GPTQ_CALL_IF(4, 16, 4, 256)
-  GPTQ_CALL_IF(4, 8, 8, 256)
-  GPTQ_CALL_IF(4, 8, 4, 128)
-  GPTQ_CALL_IF(4, 4, 8, 128)
-  GPTQ_CALL_IF(8, 16, 4, 256)
-  GPTQ_CALL_IF(8, 8, 8, 256)
-  GPTQ_CALL_IF(8, 8, 4, 128)
-  GPTQ_CALL_IF(8, 4, 8, 128)
+  GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256)
+  GPTQ_CALL_IF(vllm::kU4B8, 8, 8, 256)
+  GPTQ_CALL_IF(vllm::kU4B8, 8, 4, 128)
+  GPTQ_CALL_IF(vllm::kU4B8, 4, 8, 128)
+  GPTQ_CALL_IF(vllm::kU8B128, 16, 4, 256)
+  GPTQ_CALL_IF(vllm::kU8B128, 8, 8, 256)
+  GPTQ_CALL_IF(vllm::kU8B128, 8, 4, 128)
+  GPTQ_CALL_IF(vllm::kU8B128, 4, 8, 128)

-  AWQ_CALL_IF(4, 16, 4, 256)
-  AWQ_CALL_IF(4, 8, 8, 256)
-  AWQ_CALL_IF(4, 8, 4, 128)
-  AWQ_CALL_IF(4, 4, 8, 128)
-  AWQ_CALL_IF(8, 16, 4, 256)
-  AWQ_CALL_IF(8, 8, 8, 256)
-  AWQ_CALL_IF(8, 8, 4, 128)
-  AWQ_CALL_IF(8, 4, 8, 128)
+  AWQ_CALL_IF(vllm::kU4, 16, 4, 256)
+  AWQ_CALL_IF(vllm::kU4, 8, 8, 256)
+  AWQ_CALL_IF(vllm::kU4, 8, 4, 128)
+  AWQ_CALL_IF(vllm::kU4, 4, 8, 128)
+  AWQ_CALL_IF(vllm::kU8, 16, 4, 256)
+  AWQ_CALL_IF(vllm::kU8, 8, 8, 256)
+  AWQ_CALL_IF(vllm::kU8, 8, 4, 128)
+  AWQ_CALL_IF(vllm::kU8, 4, 8, 128)
   else {
     TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
                 ", ", prob_k, "]", ", has_act_order = ", has_act_order,