[Misc] Disambiguate quantized types via a new ScalarType (#6396)

Lucas Wilkinson 2024-08-02 16:51:58 -04:00 committed by GitHub
parent b482b9a5b1
commit a8d604ca2a
29 changed files with 1111 additions and 356 deletions


@ -66,6 +66,39 @@ endif()
#
find_package(Torch REQUIRED)
#
# Add the `default` target which detects which extensions should be
# built based on platform/architecture. This is the same logic that
# setup.py uses to select which extensions should be built and should
# be kept in sync.
#
# The `default` target makes direct use of cmake easier since knowledge
# of which extensions are supported has been factored in, e.g.
#
# mkdir build && cd build
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
# cmake --build . --target default
#
add_custom_target(default)
message(STATUS "Enabling core extension.")
# Define _core_C extension
# built for (almost) every target platform, (excludes TPU and Neuron)
set(VLLM_EXT_SRC
"csrc/core/torch_bindings.cpp")
define_gpu_extension_target(
_core_C
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
USE_SABI 3
WITH_SOABI)
add_dependencies(default _core_C)
#
# Forward the non-CUDA device extensions to external CMake scripts.
#
@ -74,7 +107,7 @@ if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
if (VLLM_TARGET_DEVICE STREQUAL "cpu")
include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
else()
message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}")
return()
endif()
return()
endif()
@ -132,7 +165,7 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
endif()
#
# Define extension targets
# Define other extension targets
#
#
@ -228,21 +261,6 @@ define_gpu_extension_target(
#
# Add the `default` target which detects which extensions should be
# built based on platform/architecture. This is the same logic that
# setup.py uses to select which extensions should be built and should
# be kept in sync.
#
# The `default` target makes direct use of cmake easier since knowledge
# of which extensions are supported has been factored in, e.g.
#
# mkdir build && cd build
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
# cmake --build . --target default
#
add_custom_target(default)
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling C extension.")
add_dependencies(default _C)


@ -13,6 +13,9 @@ COPY requirements-common.txt /workspace/vllm/
COPY requirements-openvino.txt /workspace/vllm/
COPY vllm/ /workspace/vllm/vllm
COPY csrc/core /workspace/vllm/csrc/core
COPY cmake/utils.cmake /workspace/vllm/cmake/
COPY CMakeLists.txt /workspace/vllm/
COPY setup.py /workspace/vllm/
# install build requirements


@ -7,16 +7,17 @@ from benchmark_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS)
MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
marlin_24_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, quantize_weights, sort_weights)
gptq_pack, gptq_quantize_weights, sort_weights)
from vllm.scalar_type import ScalarType
from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
@ -27,13 +28,14 @@ K_FULL_OPTS = [False, True]
def bench_run(results: List[benchmark.Measurement], model: str,
act_order: bool, is_k_full: bool, num_bits: int, group_size: int,
size_m: int, size_k: int, size_n: int):
act_order: bool, is_k_full: bool, quant_type: ScalarType,
group_size: int, size_m: int, size_k: int, size_n: int):
label = "Quant Matmul"
sub_label = ("{}, act={} k_full={}, b={}, g={}, "
"MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits,
group_size, size_m, size_k, size_n))
sub_label = ("{}, act={} k_full={}, q={}, g={}, "
"MKN=({}x{}x{})".format(model, act_order, is_k_full,
str(quant_type), group_size, size_m,
size_k, size_n))
print(f"Testing: {sub_label}")
@ -50,18 +52,18 @@ def bench_run(results: List[benchmark.Measurement], model: str,
marlin_g_idx,
marlin_sort_indices,
marlin_rand_perm,
) = marlin_quantize(b, num_bits, group_size, act_order)
) = marlin_quantize(b, quant_type, group_size, act_order)
# Marlin_24 quant
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)
marlin_24_s) = marlin_24_quantize(b, quant_type, group_size)
marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)
# GPTQ quant
(w_ref, q_w, s, g_idx,
rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)
rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order)
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
# For act_order, sort the "weights" and "g_idx"
# so that group ids are increasing
@ -75,10 +77,11 @@ def bench_run(results: List[benchmark.Measurement], model: str,
marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_MAX_PARALLEL)
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
globals = {
# Gen params
"num_bits": num_bits,
"quant_type": quant_type,
"group_size": group_size,
"size_m": size_m,
"size_n": size_n,
@ -128,7 +131,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
@ -138,19 +141,19 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_gemm_fp32",
).blocked_autorange(min_run_time=min_run_time))
if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
@ -160,7 +163,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append(
benchmark.Timer(
stmt=
"q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501
"q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
@ -196,9 +199,10 @@ def main(args):
) > 0 and is_k_full not in args.limit_k_full:
continue
for num_bits in MARLIN_SUPPORTED_NUM_BITS:
if len(args.limit_num_bits
) > 0 and num_bits not in args.limit_num_bits:
for quant_type in query_marlin_supported_quant_types(
False):
if len(args.limit_num_bits) > 0 and \
quant_type.size_bits not in args.limit_num_bits:
continue
for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
@ -215,8 +219,8 @@ def main(args):
for size_m in args.batch_sizes:
bench_run(results, model, act_order, is_k_full,
num_bits, group_size, size_m, size_k,
size_n)
quant_type, group_size, size_m,
size_k, size_n)
compare = benchmark.Compare(results)
compare.print()


@ -113,6 +113,5 @@ define_gpu_extension_target(
WITH_SOABI
)
add_custom_target(default)
message(STATUS "Enabling C extension.")
add_dependencies(default _C)

csrc/core/scalar_type.hpp (new file, 382 lines)

@ -0,0 +1,382 @@
#pragma once
#include <torch/custom_class.h>
namespace vllm {
//
// ScalarType can represent a wide range of floating point and integer types;
// in particular, it can be used to represent sub-byte data types (something
// that torch.dtype currently does not support).
//
// ScalarTypeTorch is a subclass of ScalarType that is compatible with
// TORCH_LIBRARY, making it accessible from Python as well, meaning this class
// can be used as an argument for custom operators, helping to simplify these
// interfaces.
//
// The type definitions on the Python side can be found in vllm/_core_ext.pyi;
// these type definitions should be kept up to date with any Python API changes
// here.
//
class ScalarType {
public:
enum NanRepr : int64_t {
NAN_NONE = 0, // nans are not supported
NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s
NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s
NAN_REPR_ID_MAX
};
constexpr ScalarType(bool signed_, int64_t exponent, int64_t mantissa,
int64_t bias, bool finite_values_only = false,
NanRepr nan_repr = NAN_IEEE_754)
: exponent(exponent),
mantissa(mantissa),
bias(bias),
signed_(signed_),
finite_values_only(finite_values_only),
nan_repr(nan_repr){};
static constexpr ScalarType int_(int64_t size_bits, int64_t bias = 0) {
return ScalarType(true, 0, size_bits - 1, bias);
}
static constexpr ScalarType uint(int64_t size_bits, int64_t bias = 0) {
return ScalarType(false, 0, size_bits, bias);
}
// IEEE 754 compliant floating point type
static constexpr ScalarType float_IEEE754(int64_t exponent,
int64_t mantissa) {
TORCH_CHECK(mantissa > 0 && exponent > 0);
return ScalarType(true, exponent, mantissa, 0, false, NAN_IEEE_754);
}
// IEEE 754 non-compliant floating point type
static constexpr ScalarType float_(int64_t exponent, int64_t mantissa,
bool finite_values_only,
NanRepr nan_repr) {
TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
TORCH_CHECK(mantissa > 0 && exponent > 0);
TORCH_CHECK(nan_repr != NAN_IEEE_754,
"use `float_IEEE754` constructor for floating point types that "
"follow IEEE 754 conventions");
return ScalarType(true, exponent, mantissa, 0, finite_values_only,
nan_repr);
}
int64_t const exponent; // size of the exponent field (0 for integer types)
int64_t const mantissa; // size of the mantissa field (size of the integer
// excluding the sign bit for integer types)
int64_t const bias; // stored values equal value + bias,
// used for quantized types
bool const signed_; // flag if the type supports negative numbers (i.e. has a
// sign bit)
// Extra Floating point info
bool const finite_values_only; // i.e. no +/-inf if true
NanRepr const nan_repr; // how NaNs are represented
// (not applicable for integer types)
int64_t size_bits() const { return mantissa + exponent + is_signed(); }
bool is_signed() const { return signed_; }
bool is_integer() const { return exponent == 0; }
bool is_floating_point() const { return exponent > 0; }
bool is_ieee_754() const {
return is_floating_point() && finite_values_only == false &&
nan_repr == NAN_IEEE_754;
}
bool has_nans() const { return is_floating_point() && nan_repr != NAN_NONE; }
bool has_infs() const {
return is_floating_point() && finite_values_only == false;
}
bool has_bias() const { return bias != 0; }
private:
double _floating_point_max() const {
TORCH_CHECK(mantissa <= 52 && exponent <= 11,
"Cannot represent max/min as a double for type ", str());
uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
max_mantissa -= 1;
}
uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
TORCH_CHECK(exponent < 11,
"Cannot represent max/min as a double for type ", str());
max_exponent += 1;
}
// adjust the exponent to match that of a double
// for now we assume the exponent bias is the standard 2^(e-1) - 1 (where e
// is the number of exponent bits); there is some precedent for non-standard
// biases, for example `float8_e4m3b11fnuz` here:
// https://github.com/jax-ml/ml_dtypes, but to avoid premature complication
// we just assume the standard exponent bias until there is a need to support
// non-standard biases
uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11
uint64_t max_exponent_double =
max_exponent - exponent_bias + exponent_bias_double;
// shift the mantissa into the position for a double and
// the exponent
uint64_t double_raw =
(max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);
return *reinterpret_cast<double*>(&double_raw);
}
std::variant<int64_t, double> _raw_max() const {
if (is_floating_point()) {
return {_floating_point_max()};
} else {
TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(),
"Cannot represent max as a int64_t");
return {(int64_t(1) << mantissa) - 1};
}
}
std::variant<int64_t, double> _raw_min() const {
if (is_floating_point()) {
TORCH_CHECK(is_signed(),
"We currently assume all floating point types are signed");
constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);
double max = _floating_point_max();
uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
uint64_t min_raw = max_raw | sign_bit_double;
return {*reinterpret_cast<double*>(&min_raw)};
} else {
TORCH_CHECK(!is_signed() || size_bits() <= 64,
"Cannot represent min as a int64_t");
if (is_signed()) {
// set the top bit to 1 (i.e. INT64_MIN) and the rest to 0
// then perform an arithmetic shift right to set all the bits above
// (size_bits() - 1) to 1
return {INT64_MIN >> (64 - size_bits())};
} else {
return {int64_t(0)};
}
}
}
public:
// Max representable value for this scalar type.
// (accounting for bias if there is one)
std::variant<int64_t, double> max() const {
return std::visit(
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
_raw_max());
}
// Min representable value for this scalar type.
// (accounting for bias if there is one)
std::variant<int64_t, double> min() const {
return std::visit(
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
_raw_min());
}
std::string str() const {
/* naming generally follows: https://github.com/jax-ml/ml_dtypes
* for floating point types (leading f) the scheme is:
* `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
* flags:
* - no-flags: means it follows IEEE 754 conventions
* - f: means finite values only (no infinities)
* - n: means nans are supported (non-standard encoding)
* for integer types the scheme is:
* `[u]int<size_bits>[b<bias>]`
* - if bias is not present it means it is zero
*/
if (is_floating_point()) {
auto ret = "float" + std::to_string(size_bits()) + "_e" +
std::to_string(exponent) + "m" + std::to_string(mantissa);
if (!is_ieee_754()) {
if (finite_values_only) {
ret += "f";
}
if (nan_repr != NAN_NONE) {
ret += "n";
}
}
return ret;
} else {
auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
if (has_bias()) {
ret += "b" + std::to_string(bias);
}
return ret;
}
}
bool operator==(ScalarType const& other) const {
return mantissa == other.mantissa && exponent == other.exponent &&
bias == other.bias && signed_ == other.signed_ &&
finite_values_only == other.finite_values_only &&
nan_repr == other.nan_repr;
}
};
// Create a TORCH_LIBRARY compatible version of ScalarType (i.e. inherit from
// torch::CustomClassHolder). We use multiple inheritance here since we cannot
// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
// constructor at the same time (torch::CustomClassHolder does not have a
// constexpr destructor).
class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
public:
ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
bool _signed)
: ScalarType(exponent, mantissa, bias, _signed){};
ScalarTypeTorch(ScalarType type) : ScalarType(type){};
using Base = ScalarType;
using Self = ScalarTypeTorch;
using SelfPtr = c10::intrusive_ptr<Self>;
static SelfPtr int_(int64_t size_bits, c10::optional<int64_t> bias) {
return c10::make_intrusive<Self>(
ScalarType::int_(size_bits, bias.value_or(0)));
}
static SelfPtr uint(int64_t size_bits, c10::optional<int64_t> bias) {
return c10::make_intrusive<Self>(
ScalarType::uint(size_bits, bias.value_or(0)));
}
static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) {
return c10::make_intrusive<Self>(
ScalarType::float_IEEE754(exponent, mantissa));
}
static SelfPtr float_(int64_t exponent, int64_t mantissa,
bool finite_values_only, int64_t nan_repr) {
return c10::make_intrusive<Self>(ScalarType::float_(
exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
}
template <typename T>
static void bind_readonly_property(torch::class_<Self>& cls,
std::string const& name, T Base::*field) {
auto getter_func = [field = std::move(field)](SelfPtr const& self) {
if constexpr (std::is_member_function_pointer_v<decltype(field)>) {
return (self.get()->*field)();
} else {
return self.get()->*field;
}
};
cls.def_property(name, getter_func);
}
template <typename MemberFunc, typename Cls>
static void bind_function(torch::class_<Self>& cls, const std::string& name,
MemberFunc Cls::*member) {
cls.def(name, [member = std::move(member)](SelfPtr const& self) {
return (self.get()->*member)();
});
}
template <typename Func>
static void bind_function(torch::class_<Self>& cls, const std::string& name,
Func func) {
cls.def(name, func);
}
template <typename Func>
static void bind_static_function(torch::class_<Self>& cls,
const std::string& name, Func func) {
cls.def_static(name, func);
}
static void bind_class(torch::Library& lib) {
auto cls = lib.class_<ScalarTypeTorch>("ScalarType")
.def(torch::init<int64_t, int64_t, int64_t, bool>());
// Bind Properties
bind_readonly_property(cls, "mantissa", &Base::mantissa);
bind_readonly_property(cls, "exponent", &Base::exponent);
bind_readonly_property(cls, "bias", &Base::bias);
bind_readonly_property(cls, "signed", &Base::is_signed);
bind_readonly_property(cls, "size_bits", &Base::size_bits);
// Bind member functions
bind_function(cls, "is_signed", &Base::is_signed);
bind_function(cls, "is_integer", &Base::is_integer);
bind_function(cls, "is_floating_point", &Base::is_floating_point);
bind_function(cls, "is_ieee_754", &Base::is_ieee_754);
bind_function(cls, "has_nans", &Base::has_nans);
bind_function(cls, "has_infs", &Base::has_infs);
bind_function(cls, "has_bias", &Base::has_bias);
bind_function(cls, "max", [](SelfPtr const& self) {
return std::visit([](auto arg) { return c10::IValue(arg); },
self.get()->max());
});
bind_function(cls, "min", [](SelfPtr const& self) {
return std::visit([](auto arg) { return c10::IValue(arg); },
self.get()->min());
});
bind_function(cls, "__str__", &Base::str);
bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
return *self == *other;
});
bind_function(cls, "__repr__", [](SelfPtr const& self) {
return "ScalarType." + self.get()->str();
});
// Bind static functions (convenience constructors)
bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
bind_static_function(cls, "float_IEEE754", &ScalarTypeTorch::float_IEEE754);
bind_static_function(cls, "float_", &ScalarTypeTorch::float_);
}
};
using ScalarTypeTorchPtr = c10::intrusive_ptr<ScalarTypeTorch>;
// "rust style" names generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
static inline constexpr auto kS4 = ScalarType::int_(4);
static inline constexpr auto kU4 = ScalarType::uint(4);
static inline constexpr auto kU4B8 = ScalarType::uint(4, 8);
static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
static inline constexpr auto kFE3M2f =
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn =
ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);
// Fixed width style names, generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
static inline constexpr auto kInt4 = kS4;
static inline constexpr auto kUint4 = kU4;
static inline constexpr auto kUint4b8 = kU4B8;
static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128;
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
static inline constexpr auto kFloat16_e8m7 = kFE8M7;
static inline constexpr auto kFloat16_e5m10 = kFE5M10;
// colloquial names
static inline constexpr auto kHalf = kFE5M10;
static inline constexpr auto kFloat16 = kHalf;
static inline constexpr auto kBFloat16 = kFE8M7;
}; // namespace vllm
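The constants above establish the naming vocabulary the rest of this change relies on (uint4b8, uint8b128, float8_e4m3fn, ...). A minimal sketch of the intended semantics, assuming the compiled vllm._core_C extension is available and that vllm.scalar_type.scalar_types mirrors these constants (as exercised in tests/test_scalartype.py later in this diff):
# Sketch only: assumes the _core_C extension is built and vllm.scalar_type
# exposes the constants defined in csrc/core/scalar_type.hpp.
from vllm.scalar_type import scalar_types

uint4b8 = scalar_types.uint4b8    # kU4B8: 4-bit unsigned with a bias of 8
assert uint4b8.size_bits == 4
assert (uint4b8.min(), uint4b8.max()) == (-8, 7)  # bias shifts stored 0..15 to -8..7

fp8 = scalar_types.float8_e4m3fn  # kFE4M3fn: finite-only, non-IEEE NaN encoding
assert fp8.is_floating_point()
assert not fp8.has_infs()         # finite_values_only is set for this type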


@ -0,0 +1,16 @@
#include <torch/library.h>
#include "scalar_type.hpp"
#include "registration.h"
// Note the CORE extension will be built for (almost) all hardware targets, so
// new additions must account for this (currently not built for TPU and Neuron).
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, lib) {
// ScalarType, a custom class for representing data types that supports
// quantized types, declared here so it can be used when creating interfaces
// for custom ops.
vllm::ScalarTypeTorch::bind_class(lib);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
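Once this extension is built, the bound class surfaces through torch.classes, which is how vllm/_core_ext.py (later in this diff) picks it up. A hedged sketch of that path:
# Sketch: reaching the class registered above from Python; assumes the
# vllm._core_C extension was built for this platform.
import torch
import vllm._core_C  # noqa: F401  (loads the TORCH_LIBRARY registration)

ScalarType = torch.classes._core_C.ScalarType
t = ScalarType.uint(4, 8)              # static constructor bound via bind_static_function
print(t.size_bits, t.min(), t.max())   # 4 -8 7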


@ -1,6 +1,6 @@
#include "cache.h"
#include "ops.h"
#include "registration.h"
#include "core/registration.h"
#include <torch/library.h>


@ -1,4 +1,4 @@
#include "registration.h"
#include "core/registration.h"
#include "moe_ops.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {


@ -3,6 +3,8 @@
#include <optional>
#include <torch/library.h>
#include "core/scalar_type.hpp"
void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
@ -84,14 +86,16 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
torch::Tensor& workspace, int64_t num_bits,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n,
int64_t size_k);
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace, int64_t num_bits,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp,
bool use_fp32_reduce);


@ -21,6 +21,7 @@
#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
@ -71,14 +72,15 @@ __global__ void Marlin(
bool use_fp32_reduce // whether to use fp32 global reduce
) {}
} // namespace gptq_marlin
} // namespace marlin
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace, int64_t num_bits,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full) {
bool is_k_full, bool has_zp) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
@ -1963,18 +1965,29 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
template <typename scalar_t>
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
void* s, void* zp, void* g_idx, void* perm, void* a_tmp,
int prob_m, int prob_n, int prob_k, void* workspace,
int num_bits, bool has_act_order, bool is_k_full,
bool has_zp, int num_groups, int group_size, int dev,
cudaStream_t stream, int thread_k, int thread_n, int sms,
int max_par, bool use_fp32_reduce) {
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m,
int prob_n, int prob_k, void* workspace,
vllm::ScalarType const& q_type, bool has_act_order,
bool is_k_full, bool has_zp, int num_groups, int group_size,
int dev, cudaStream_t stream, int thread_k, int thread_n,
int sms, int max_par, bool use_fp32_reduce) {
if (has_zp) {
TORCH_CHECK(
q_type == vllm::kU4 || q_type == vllm::kU8,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else {
TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
q_type.str());
}
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]");
// TODO: remove alias when we start supporting other 8bit types
int num_bits = q_type.size_bits();
int tot_m = prob_m;
int tot_m_blocks = div_ceil(tot_m, 16);
int pad = 16 * tot_m_blocks - tot_m;
@ -2126,19 +2139,28 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
}
}
} // namespace gptq_marlin
} // namespace marlin
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace, int64_t num_bits,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp,
bool use_fp32_reduce) {
// Verify num_bits
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
int pack_factor = 32 / num_bits;
if (has_zp) {
TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8,
"b_q_type must be u4 or u8 when has_zp = True. Got = ",
b_q_type->str());
} else {
TORCH_CHECK(
*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
b_q_type->str());
}
int pack_factor = 32 / b_q_type->size_bits();
// Verify A
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
@ -2265,21 +2287,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
marlin::marlin_mm_f16i4<half>(
marlin::marlin_mm<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
marlin::marlin_mm_f16i4<nv_bfloat16>(
marlin::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
} else {


@ -27,6 +27,7 @@
#include <iostream>
#include "common/base.h"
#include "core/scalar_type.hpp"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
@ -86,7 +87,8 @@ __global__ void Marlin_24(
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
torch::Tensor& workspace, int64_t num_bits,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n,
int64_t size_k) {
TORCH_CHECK_NOT_IMPLEMENTED(
@ -1025,13 +1027,14 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
torch::Tensor& workspace, int64_t num_bits,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n,
int64_t size_k) {
// Verify num_bits
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
int pack_factor = 32 / num_bits;
TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
"num_bits must be uint4b8 or uint8b128. Got = ", b_q_type->str());
int pack_factor = 32 / b_q_type->size_bits();
// Verify M
TORCH_CHECK(size_m == a.size(0),
@ -1126,8 +1129,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
marlin_24::marlin_cuda_2_4(
a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(),
b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(),
num_bits, groupsize, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
thread_m, sms, max_par);
b_q_type->size_bits(), groupsize, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par);
return c;
}


@ -1,7 +1,7 @@
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include "registration.h"
#include "core/registration.h"
#include <torch/library.h>


@ -271,6 +271,10 @@ def _build_custom_ops() -> bool:
return _is_cuda() or _is_hip() or _is_cpu()
def _build_core_ext() -> bool:
return not _is_neuron() and not _is_tpu()
def get_hipcc_rocm_version():
# Run the hipcc --version command
result = subprocess.run(['hipcc', '--version'],
@ -433,6 +437,9 @@ def get_requirements() -> List[str]:
ext_modules = []
if _build_core_ext():
ext_modules.append(CMakeExtension(name="vllm._core_C"))
if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._moe_C"))
@ -477,7 +484,7 @@ setup(
extras_require={
"tensorizer": ["tensorizer>=2.9.0"],
},
cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {},
cmdclass={"build_ext": cmake_build_ext} if len(ext_modules) > 0 else {},
package_data=package_data,
entry_points={
"console_scripts": [


@ -1,8 +1,6 @@
import pytest
import torch
# ruff: noqa: F401
import vllm._C
from tests.kernels.quant_utils import ref_dynamic_per_token_quant
from vllm._custom_ops import scaled_int8_quant


@ -9,14 +9,14 @@ from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.qqq import (
MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N,
MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS,
marlin_make_empty_g_idx, marlin_permute_scales)
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
marlin_permute_scales, query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
pack_fp8_to_int32)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
@ -27,8 +27,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import ( # noqa: E501
marlin_qqq_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
awq_pack, gptq_pack, quantize_weights, quantize_weights_with_zp,
sort_weights)
awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
@ -65,12 +64,13 @@ def rand_data(shape, dtype=torch.float16):
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
mnk_factors):
def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
act_order, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
@ -95,11 +95,11 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
b_weight = rand_data((size_k, size_n))
# Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = quantize_weights(b_weight, num_bits,
group_size, act_order)
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
b_weight, quant_type, group_size, act_order)
# Pack to GPTQ format
q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
@ -108,8 +108,9 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
# Pack to Marlin format
weight_perm = get_weight_perm(num_bits)
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm)
# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.gptq_marlin_repack(
@ -117,7 +118,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
sort_indices,
size_k,
size_n,
num_bits,
quant_type.size_bits,
)
torch.cuda.synchronize()
@ -128,10 +129,11 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
mnk_factors):
m_factor, n_factor, k_factor = mnk_factors
@ -150,22 +152,25 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
b_weight = rand_data((size_k, size_n))
# Quantize
w_ref, q_w, s, zp = quantize_weights_with_zp(b_weight, num_bits,
group_size)
w_ref, q_w, s, zp = quantize_weights(b_weight,
quant_type,
group_size,
zero_points=True)
# Pack to AWQ format
q_w_awq = awq_pack(q_w, num_bits, size_k, size_n)
q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)
# Pack to Marlin format
weight_perm = get_weight_perm(num_bits)
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm)
# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.awq_marlin_repack(
q_w_awq,
size_k,
size_n,
num_bits,
quant_type.size_bits,
)
torch.cuda.synchronize()
@ -176,7 +181,8 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@ -185,7 +191,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
def test_gptq_marlin_gemm(
k_chunk,
n_chunk,
num_bits,
quant_type,
group_size,
mnk_factors,
act_order,
@ -211,7 +217,7 @@ def test_gptq_marlin_gemm(
b_weight = rand_data((size_k, size_n))
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, num_bits, group_size, act_order)
b_weight, quant_type, group_size, act_order)
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
@ -226,7 +232,7 @@ def test_gptq_marlin_gemm(
g_idx,
sort_indices,
workspace.scratch,
num_bits,
quant_type,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
@ -248,10 +254,10 @@ def test_gptq_marlin_gemm(
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size,
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
mnk_factors):
m_factor, n_factor, k_factor = mnk_factors
@ -266,7 +272,7 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size,
b_weight = rand_data((size_k, size_n))
(w_24_ref, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s) = marlin_24_quantize(b_weight, num_bits, group_size)
marlin_24_s) = marlin_24_quantize(b_weight, quant_type, group_size)
workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_MAX_PARALLEL)
@ -279,7 +285,7 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size,
marlin_24_meta,
marlin_24_s,
workspace_24.scratch,
num_bits,
quant_type,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
@ -371,14 +377,15 @@ def test_fp8_marlin_gemm(
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_awq_marlin_gemm(
k_chunk,
n_chunk,
num_bits,
quant_type,
group_size,
mnk_factors,
use_fp32_reduce,
@ -396,7 +403,7 @@ def test_awq_marlin_gemm(
b_weight = rand_data((size_k, size_n))
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
b_weight, num_bits, group_size)
b_weight, quant_type, group_size)
g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
@ -414,7 +421,7 @@ def test_awq_marlin_gemm(
g_idx,
sort_indices,
workspace.scratch,
num_bits,
quant_type,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],

tests/test_scalartype.py (new file, 36 lines)

@ -0,0 +1,36 @@
import pytest
import torch
from vllm.scalar_type import scalar_types
@pytest.mark.parametrize("type_tuple", (
(-8, 7, scalar_types.int4),
(0, 15, scalar_types.uint4),
(-8, 7, scalar_types.uint4b8),
(-128, 127, scalar_types.uint8b128),
(-28., 28., scalar_types.float6_e3m2f),
(torch.int8, scalar_types.int8),
(torch.uint8, scalar_types.uint8),
(torch.float8_e5m2, scalar_types.float8_e5m2),
(torch.float8_e4m3fn, scalar_types.float8_e4m3fn),
(torch.bfloat16, scalar_types.float16_e8m7),
(torch.float16, scalar_types.float16_e5m10),
),
ids=lambda x: str(x))
def test_scalar_type_min_max(type_tuple):
print(type_tuple)
if len(type_tuple) == 3:
min, max, t = type_tuple
else:
torch_type, t = type_tuple
if torch_type.is_floating_point:
min = torch.finfo(torch_type).min
max = torch.finfo(torch_type).max
else:
min = torch.iinfo(torch_type).min
max = torch.iinfo(torch_type).max
print(t, min, max, t.min(), t.max())
assert min == t.min()
assert max == t.max()

vllm/_core_ext.py (new file, 177 lines)

@ -0,0 +1,177 @@
import importlib.util
from enum import Enum
from typing import TYPE_CHECKING, Optional, Union
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
core_C_available = importlib.util.find_spec('._core_C', 'vllm') is not None
# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
NONE = 0 # nans are not supported
IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
if TYPE_CHECKING or not core_C_available:
# On platforms where we cannot use/build the C++ core extension (namely
# Neuron and TPU), we define a mock ScalarType class here that partially
# mimics the C++ ScalarType class.
#
# We also use this to provide type signatures to the Python LSP for the
# methods in the C++ ScalarType class, so these type signatures should be
# kept in sync with csrc/core/scalar_type.hpp.
from dataclasses import dataclass
@dataclass(frozen=True)
class ScalarType:
"""
ScalarType can represent a wide range of floating point and integer
types; in particular, it can be used to represent sub-byte data types
(something that torch.dtype currently does not support). It is also
capable of representing types with a bias, i.e.
`stored_value = value + bias`,
which is useful for quantized types (e.g. standard GPTQ 4-bit uses a bias
of 8). The implementation for this class can be found in
csrc/core/scalar_type.hpp; these type signatures should be kept in sync
with that file.
"""
exponent: int
"""
Number of bits in the exponent if this is a floating point type
(zero if this is an integer type)
"""
mantissa: int
"""
Number of bits in the mantissa if this is a floating point type,
or the number of bits representing an integer excluding the sign bit if
this is an integer type.
"""
bias: int
"""
bias used to encode the values in this scalar type
(value = stored_value - bias, default 0). For example, if we store the
type as an unsigned integer with a bias of 128, then the value 0 will be
stored as 128, -1 will be stored as 127, and 1 will be stored as 129.
"""
signed: bool
"If the type is signed (i.e. has a sign bit)"
_finite_values_only: bool = False
"""
Private: to check whether infs are supported, use `has_infs()` instead.
"""
nan_repr: int = NanRepr.IEEE_754.value
"""
How NaNs are represented in this scalar type; returns a NanRepr value
(not applicable for integer types).
"""
@property
def size_bits(self):
return self.exponent + self.mantissa + int(self.signed)
def min(self) -> Union[int, float]:
"""
Min representable value for this scalar type.
(accounting for bias if there is one)
"""
raise NotImplementedError
def max(self) -> Union[int, float]:
"""
Max representable value for this scalar type.
(accounting for bias if there is one)
"""
raise NotImplementedError
def is_signed(self) -> bool:
"""
If the type is signed (i.e. has a sign bit); same as `signed`,
added for consistency with:
https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
"""
...
def is_floating_point(self):
"If the type is a floating point type"
return self.exponent != 0
def is_integer(self):
"If the type is an integer type"
return self.exponent == 0
def has_bias(self):
"If the type has a non-zero bias"
return self.bias != 0
def has_infs(self):
"If the type is floating point and supports infinity"
return not self._finite_values_only
def has_nans(self):
return self.nan_repr != NanRepr.NONE.value
def is_ieee_754(self) -> bool:
"""
If the type is a floating point type that follows IEEE 754
conventions
"""
return self.nan_repr == NanRepr.IEEE_754.value and \
not self._finite_values_only
def __str__(self) -> str:
raise NotImplementedError
def __repr__(self) -> str:
raise NotImplementedError
#
# Convenience Constructors
#
@classmethod
def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
"Create a signed integer scalar type (size_bits includes sign-bit)."
return cls(size_bits - 1, size_bits, bias if bias else 0, True)
@classmethod
def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
"""Create a unsigned integer scalar type."""
return cls(size_bits, size_bits, bias if bias else 0, False)
@classmethod
def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
"""
Create a standard floating point type
(i.e. follows IEEE 754 conventions).
"""
return cls(exponent, mantissa, 0, True)
@classmethod
def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
nan_repr: int):
"""
Create a non-standard floating point type
(i.e. does not follow IEEE 754 conventions).
"""
return cls(exponent, mantissa, 0, True, finite_values_only,
nan_repr)
elif core_C_available:
try:
import vllm._core_C # noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._core_C with %r", e)
ScalarType = torch.classes._core_C.ScalarType
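The bias convention documented in the docstrings above (stored_value = value + bias) is what lets a single unsigned storage format describe several logical ranges. A small worked example using the types asserted in tests/test_scalartype.py:
# Worked example of the bias convention; values match tests/test_scalartype.py.
from vllm.scalar_type import scalar_types

# uint4:     stored 0..15,  bias 0   -> represents 0..15
# uint4b8:   stored 0..15,  bias 8   -> represents -8..7 (standard GPTQ 4-bit)
# uint8b128: stored 0..255, bias 128 -> represents -128..127
for t in (scalar_types.uint4, scalar_types.uint4b8, scalar_types.uint8b128):
    print(str(t), t.min(), t.max())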


@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Union
import torch
from vllm._core_ext import ScalarType
from vllm.logger import init_logger
logger = init_logger(__name__)
@ -220,10 +221,10 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor, num_bits: int, size_m: int,
size_n: int, size_k: int) -> torch.Tensor:
workspace: torch.Tensor, b_q_type: ScalarType,
size_m: int, size_n: int, size_k: int) -> torch.Tensor:
return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
workspace, num_bits, size_m,
workspace, b_q_type, size_m,
size_n, size_k)
@ -279,14 +280,22 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, b_zeros: torch.Tensor,
g_idx: torch.Tensor, perm: torch.Tensor,
workspace: torch.Tensor, num_bits: int, size_m: int,
size_n: int, size_k: int, is_k_full: bool, has_zp: bool,
use_fp32_reduce: bool) -> torch.Tensor:
def gptq_marlin_gemm(a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
b_zeros: torch.Tensor,
g_idx: torch.Tensor,
perm: torch.Tensor,
workspace: torch.Tensor,
b_q_type: ScalarType,
size_m: int,
size_n: int,
size_k: int,
is_k_full: bool,
has_zp: bool = False,
use_fp32_reduce: bool = False) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
g_idx, perm, workspace, num_bits,
g_idx, perm, workspace, b_q_type,
size_m, size_n, size_k, is_k_full,
has_zp, use_fp32_reduce)
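The updated wrapper above is exercised end-to-end in tests/kernels/test_marlin_gemm.py; below is a condensed sketch of the new calling convention, where a ScalarType replaces the old num_bits argument. It assumes a CUDA-capable GPU and the Marlin test helpers shown earlier in this diff; the shapes (16x128x256) and group size of 128 are illustrative choices, not values taken from the original.
# Condensed sketch adapted from the updated test_gptq_marlin_gemm.
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_make_empty_g_idx)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace, marlin_quantize)
from vllm.scalar_type import scalar_types

quant_type = scalar_types.uint4b8                       # GPTQ 4-bit symmetric
a = torch.randn((16, 128), dtype=torch.float16, device="cuda")
b = torch.randn((128, 256), dtype=torch.float16, device="cuda")

# Quantize the weight to Marlin format (act_order=False, group_size=128).
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
    b, quant_type, 128, False)
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
workspace = MarlinWorkspace(b.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
                            GPTQ_MARLIN_MAX_PARALLEL)

out = ops.gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, g_idx,
                           sort_indices, workspace.scratch, quant_type,
                           a.shape[0], b.shape[1], a.shape[1],
                           is_k_full=True)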


@ -10,11 +10,11 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points,
check_awq_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, replace_tensor, verify_awq_marlin_supported,
verify_marlin_supports_shape)
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
replace_tensor, verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@ -22,20 +22,31 @@ logger = init_logger(__name__)
class AWQMarlinConfig(QuantizationConfig):
"""Config class for AWQ Marlin"""
# num_bits -> type
TYPE_MAP = {
4: scalar_types.uint4,
8: scalar_types.uint8,
}
def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
lm_head_quantized: bool) -> None:
self.weight_bits = weight_bits
self.pack_factor = 32 // self.weight_bits # packed into 32bits
self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size
self.has_zp = has_zp
self.lm_head_quantized = lm_head_quantized
verify_awq_marlin_supported(num_bits=self.weight_bits,
group_size=self.group_size,
has_zp=self.has_zp)
if weight_bits not in self.TYPE_MAP:
raise ValueError(f"Unsupported num_bits = {weight_bits}. "
f"Supported num_bits = {self.TYPE_MAP.keys()}")
self.quant_type = self.TYPE_MAP[weight_bits]
verify_marlin_supported(self.quant_type,
group_size=self.group_size,
has_zp=self.has_zp)
def __repr__(self) -> str:
return (f"AWQMarlinConfig(weight_bits={self.weight_bits}, "
return (f"AWQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, "
f"has_zp={self.has_zp}, "
f"lm_head_quantized={self.lm_head_quantized})")
@ -110,11 +121,13 @@ class AWQMarlinConfig(QuantizationConfig):
if (num_bits is None or group_size is None or has_zp is None):
return False
return check_awq_marlin_supported(
num_bits=num_bits,
group_size=group_size,
has_zp=has_zp,
min_capability=cls.get_min_capability())
if num_bits not in cls.TYPE_MAP:
return False
return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
group_size=group_size,
has_zp=has_zp,
min_capability=cls.get_min_capability())
class AWQMarlinLinearMethod(LinearMethodBase):
@ -226,7 +239,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
layer.qweight,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.weight_bits)
num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qweight", marlin_qweight)
# Permute scales from AWQ format to marlin format.
@ -242,7 +255,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
layer.qzeros,
size_k=layer.num_groups,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.weight_bits)
num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qzeros", marlin_zp)
# Not-used
@ -263,7 +276,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
num_bits=self.quant_config.weight_bits,
quant_type=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
bias=bias)


@ -9,9 +9,13 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
from vllm.model_executor.utils import set_weight_attrs
from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsW4A16Sparse24"]
W4A16SPARSE24_SUPPORTED_BITS = [4]
W4A16SPARSE24_SUPPORTED_TYPES_MAP = {
4: scalar_types.uint4b8,
}
W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
@ -22,9 +26,15 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
group_size: Optional[int] = None):
self.strategy = strategy
self.group_size = group_size
self.num_bits = num_bits
self.tile_size = 16
if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP:
raise ValueError(
f"Unsupported num_bits = {num_bits}. "
f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}")
self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits]
if self.strategy == "group" and self.group_size is None:
raise ValueError(
"group_size must be given when using strategy group")
@ -43,7 +53,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
pack_factor = 32 // self.num_bits
pack_factor = 32 // self.quant_type.size_bits
output_size_per_partition = sum(output_partition_sizes)
qweight = Parameter(
@ -138,7 +148,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
size_n = scales.shape[1]
output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
workspace, self.num_bits, size_m,
workspace, self.quant_type, size_m,
size_n, size_k)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))


@ -8,12 +8,17 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, replace_tensor, verify_gptq_marlin_supported,
marlin_permute_scales, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.utils import set_weight_attrs
from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsWNA16"]
WNA16_SUPPORTED_BITS = [4, 8]
WNA16_SUPPORTED_TYPES_MAP = {
4: scalar_types.uint4b8,
8: scalar_types.uint8b128,
}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsWNA16(CompressedTensorsScheme):
@ -22,8 +27,8 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
strategy: str,
num_bits: int,
group_size: Optional[int] = None):
self.num_bits = num_bits
self.pack_factor = 32 // self.num_bits
self.pack_factor = 32 // num_bits
self.strategy = strategy
self.group_size: int
@ -37,10 +42,16 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
else:
self.group_size = group_size
if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
raise ValueError(
f"Unsupported num_bits = {num_bits}. "
f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")
self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]
# Verify supported on platform.
verify_gptq_marlin_supported(num_bits=self.num_bits,
group_size=self.group_size,
is_sym=True)
verify_marlin_supported(quant_type=self.quant_type,
group_size=self.group_size)
@classmethod
def get_min_capability(cls) -> int:
@ -150,7 +161,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.num_bits)
num_bits=self.quant_type.size_bits)
replace_tensor(layer, "weight_packed", marlin_qweight)
# Permute scales from compressed-tensors format to marlin format.
@ -172,7 +183,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
num_bits=self.num_bits,
wtype=self.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
is_k_full=True,


@ -10,11 +10,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, check_gptq_marlin_supported, marlin_is_k_full,
apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
verify_gptq_marlin_supported, verify_marlin_supports_shape)
verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@ -22,6 +23,12 @@ logger = init_logger(__name__)
class GPTQMarlinConfig(QuantizationConfig):
"""Config class for GPTQ Marlin"""
# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool, lm_head_quantized: bool) -> None:
if desc_act and group_size == -1:
@ -29,20 +36,23 @@ class GPTQMarlinConfig(QuantizationConfig):
# (since we have only one group per output channel)
desc_act = False
self.weight_bits = weight_bits
self.pack_factor = 32 // self.weight_bits # packed into int32
self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
self.lm_head_quantized = lm_head_quantized
if (weight_bits, is_sym) not in self.TYPE_MAP:
raise ValueError("Unsupported quantization config: "
f"bits={weight_bits}, sym={is_sym}")
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
# Verify supported on platform.
verify_gptq_marlin_supported(num_bits=self.weight_bits,
group_size=self.group_size,
is_sym=self.is_sym)
verify_marlin_supported(quant_type=self.quant_type,
group_size=self.group_size)
def __repr__(self) -> str:
return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized})")
@ -122,11 +132,12 @@ class GPTQMarlinConfig(QuantizationConfig):
or desc_act is None):
return False
return check_gptq_marlin_supported(
num_bits=num_bits,
group_size=group_size,
is_sym=sym,
min_capability=cls.get_min_capability())
if (num_bits, sym) not in cls.TYPE_MAP:
return False
return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
group_size=group_size,
min_capability=cls.get_min_capability())
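
For reference, a hedged sketch of how the new TYPE_MAP and check_marlin_supported fit together; the capability value 80 is only an illustration of the minimum the Marlin path checks for, the real config passes cls.get_min_capability():

from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    check_marlin_supported)
from vllm.scalar_type import scalar_types

TYPE_MAP = {(4, True): scalar_types.uint4b8, (8, True): scalar_types.uint8b128}

num_bits, sym, group_size = 4, True, 128
ok = ((num_bits, sym) in TYPE_MAP and check_marlin_supported(
    quant_type=TYPE_MAP[(num_bits, sym)],
    group_size=group_size,
    min_capability=80))  # illustrative capability value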
class GPTQMarlinLinearMethod(LinearMethodBase):
@ -293,7 +304,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.weight_bits)
num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qweight", marlin_qweight)
# Permute scales from autogptq format to marlin format.
@ -319,7 +330,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
num_bits=self.quant_config.weight_bits,
wtype=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
is_k_full=layer.is_k_full,

View File

@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@ -17,9 +18,10 @@ GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [
scalar_types.uint4b8, scalar_types.uint8b128
]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
GPTQ_MARLIN_24_SUPPORTED_SYM = [True]
class GPTQMarlin24Config(QuantizationConfig):
@ -31,14 +33,19 @@ class GPTQMarlin24Config(QuantizationConfig):
weight_bits: int,
group_size: int,
) -> None:
self.weight_bits = weight_bits
quant_type = {
4: scalar_types.uint4b8,
8: scalar_types.uint8b128,
}.get(weight_bits)
self.group_size = group_size
# Verify
if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
if quant_type is None or \
quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES:
raise ValueError(
f"Marlin_24 does not support weight_bits = {self.weight_bits}. "
f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} "
f"Marlin_24 does not support quant_type = {quant_type}. "
f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} "
"are supported.")
if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
raise ValueError(
@ -46,8 +53,10 @@ class GPTQMarlin24Config(QuantizationConfig):
f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
"are supported.")
self.quant_type = quant_type
# 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // self.weight_bits
self.pack_factor = 32 // self.quant_type.size_bits
# Tile size used by marlin kernels.
self.tile_size = 16
@ -66,8 +75,8 @@ class GPTQMarlin24Config(QuantizationConfig):
self.perm_len = 1024
def __repr__(self) -> str:
return "Marlin24Config(weight_bits={}, group_size={})".format(
self.weight_bits, self.group_size)
return "Marlin24Config(quant_type={}, group_size={})".format(
self.quant_type, self.group_size)
@classmethod
def get_name(cls) -> str:
@ -279,7 +288,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
workspace,
self.quant_config.weight_bits,
self.quant_config.quant_type,
size_m, size_n, size_k)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

View File

@ -5,6 +5,7 @@ import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
from .quant_utils import pack_cols, unpack_cols
@ -13,7 +14,6 @@ GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
MARLIN_SUPPORTED_NUM_BITS = [4, 8]
MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# In case there is a performance issue with Marlin, the variable below can be
@ -22,76 +22,70 @@ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
USE_FP32_REDUCE_DEFAULT = True
def _check_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
min_capability: Optional[int],
has_zp: bool) -> Tuple[bool, Optional[str]]:
if min_capability is not None:
# For binary size and compile time, we don't support the same types for with and
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def query_marlin_supported_quant_types(has_zp: bool,
min_capability: Optional[int] = None):
if min_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
if device_capability < min_capability:
return (False, "Marlin does not support device_capability = {}"
", the min_capability required is {}".format(
device_capability, min_capability))
min_capability = major * 10 + minor
if num_bits not in MARLIN_SUPPORTED_NUM_BITS:
return (False, "Marlin does not support weight_bits = {}. "
"Only weight_bits = {} are supported.".format(
num_bits, MARLIN_SUPPORTED_NUM_BITS))
if min_capability < 80:
return []
if group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
return (False, "Marlin does not support group_size = {}. Only "
"group_sizes = {} are supported.".format(
group_size, MARLIN_SUPPORTED_GROUP_SIZES))
if has_zp:
# AWQ style, unsigned + runtime zero-point
return [scalar_types.uint4, scalar_types.uint8]
else:
# GPTQ style, unsigned + symmetric bias
# TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
# to add `scalar_types.float8_e4m3fn` here
return [scalar_types.uint4b8, scalar_types.uint8b128]
if not has_zp and not is_sym:
return (False,
"Marlin without zero_points must have symmetric quantization")
def _check_marlin_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
min_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if min_capability is None:
major, minor = current_platform.get_device_capability()
min_capability = major * 10 + minor
supported_types = query_marlin_supported_quant_types(
has_zp, min_capability)
if quant_type not in supported_types:
return (False, f"Marlin does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"min_capability = {min_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
return (False, f"Marlin does not support group_size = {group_size}. "
f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
"are supported.")
return True, None
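
To make the two branches above concrete, a small sketch of the values query_marlin_supported_quant_types is expected to return (min_capability pinned to 80 purely for illustration):

from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    query_marlin_supported_quant_types)

# GPTQ-style: no runtime zero-point, types carry a built-in bias
print(query_marlin_supported_quant_types(has_zp=False, min_capability=80))
# expected: [uint4b8, uint8b128]

# AWQ-style: runtime zero-point, plain unsigned types
print(query_marlin_supported_quant_types(has_zp=True, min_capability=80))
# expected: [uint4, uint8]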
def check_gptq_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
min_capability: int) -> bool:
cond, _ = _check_marlin_supported(num_bits,
group_size,
is_sym,
min_capability,
has_zp=False)
def check_marlin_supported(quant_type: ScalarType,
group_size: int,
has_zp: bool = False,
min_capability: Optional[int] = None) -> bool:
cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
min_capability)
return cond
def check_awq_marlin_supported(num_bits: int, group_size: int, has_zp: bool,
min_capability: int) -> bool:
cond, _ = _check_marlin_supported(num_bits,
group_size,
False,
min_capability,
has_zp=has_zp)
return cond
def verify_gptq_marlin_supported(num_bits: int, group_size: int,
is_sym: bool) -> None:
cond, err_msg = _check_marlin_supported(num_bits,
group_size,
is_sym,
min_capability=None,
has_zp=False)
def verify_marlin_supported(quant_type: ScalarType,
group_size: int,
has_zp: bool = False) -> None:
cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
if not cond:
assert err_msg is not None
raise ValueError("GPTQ" + err_msg)
def verify_awq_marlin_supported(num_bits: int, group_size: int,
has_zp: bool) -> None:
cond, err_msg = _check_marlin_supported(num_bits,
group_size,
False,
min_capability=None,
has_zp=has_zp)
if not cond:
assert err_msg is not None
raise ValueError("AWQ" + err_msg)
raise ValueError(err_msg)
def verify_marlin_supports_shape(output_size_per_partition: int,
@ -245,7 +239,7 @@ def apply_gptq_marlin_linear(
g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
wtype: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
is_k_full: bool,
@ -261,7 +255,7 @@ def apply_gptq_marlin_linear(
g_idx,
g_idx_sort_indices,
workspace,
num_bits,
wtype,
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
@ -283,7 +277,7 @@ def apply_awq_marlin_linear(
g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
quant_type: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
bias: Optional[torch.Tensor] = None,
@ -298,7 +292,7 @@ def apply_awq_marlin_linear(
g_idx,
g_idx_sort_indices,
workspace,
num_bits,
quant_type,
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,

View File

@ -5,10 +5,12 @@ from typing import List
import numpy as np
import torch
from vllm.scalar_type import ScalarType
from .marlin_utils import (GPTQ_MARLIN_TILE, marlin_permute_scales,
marlin_zero_points)
from .quant_utils import (get_pack_factor, quantize_weights,
quantize_weights_with_zp, sort_weights)
from .quant_utils import (get_pack_factor, gptq_quantize_weights,
quantize_weights, sort_weights)
class MarlinWorkspace:
@ -90,9 +92,10 @@ def get_weight_perm(num_bits: int):
return perm
def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int,
act_order: bool):
size_k, size_n = w.shape
num_bits = quant_type.size_bits
# Normalize group_size
if group_size == -1:
@ -100,8 +103,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
assert group_size <= size_k
# Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
act_order)
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
w, quant_type, group_size, act_order)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
@ -122,7 +125,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
return res_list
def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int):
def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType,
group_size: int):
size_k, size_n = w.shape
# Normalize group_size
@ -135,13 +139,18 @@ def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int):
num_groups = size_k // group_size
# Quantize with zp
w_ref, q_w, s, zp = quantize_weights_with_zp(w, num_bits, group_size)
w_ref, q_w, s, zp = quantize_weights(w,
quant_type,
group_size,
zero_points=True)
# Reformat to marlin
weight_perm = get_weight_perm(num_bits)
marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
marlin_zp = marlin_zero_points(zp, num_groups, size_n, num_bits)
marlin_zp = marlin_zero_points(zp, num_groups, size_n,
quant_type.size_bits)
# Create result
res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
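
A hedged usage sketch of the updated test helpers, which now take a ScalarType instead of num_bits (weight shape and group size are arbitrary):

import torch
from vllm.scalar_type import scalar_types
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    awq_marlin_quantize, marlin_quantize)

w = torch.randn(512, 256, dtype=torch.half)

# GPTQ path: biased type, optional act_order
gptq_res = marlin_quantize(w, scalar_types.uint4b8, group_size=128, act_order=False)

# AWQ path: unsigned type + zero-points; returns [w_ref, marlin_q_w, marlin_s, marlin_zp]
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
    w, scalar_types.uint4, group_size=128)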

View File

@ -6,8 +6,10 @@ from typing import List
import numpy
import torch
from vllm.scalar_type import ScalarType
from .marlin_utils_test import marlin_weights
from .quant_utils import quantize_weights
from .quant_utils import gptq_quantize_weights
# This is PyTorch implementation of main part of reorder_meta()
@ -348,13 +350,11 @@ def check_24(w, num_rows_to_sample=50, _verbose=False):
print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
assert q_24.shape == (size_k, size_n)
# Remove zp to normalize over 0
max_q_val = (1 << num_bits) - 1
zp = (max_q_val + 1) // 2
q_24_no_zp = q_24 - zp
# Remove bias to normalize over 0
q_24_no_zp = q_24 - wtype.bias
# Compress
q_24_no_zp = q_24_no_zp.t().contiguous()
@ -362,8 +362,8 @@ def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
q_24_no_zp)
q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()
# Restore zp
q_24_comp = q_24_no_zp_comp + zp
# Restore bias
q_24_comp = q_24_no_zp_comp + wtype.bias
# Resize meta to its actual shape (without moving any data)
meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
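
The bias handling above can be illustrated numerically; this stand-alone snippet only mirrors the centering step and does not call the sparse compressor:

import torch
from vllm.scalar_type import scalar_types

wtype = scalar_types.uint4b8            # 4-bit unsigned stored with a bias of 8
q_24 = torch.tensor([[8, 9, 7, 10]])    # stored codes; 8 encodes the value 0

q_centered = q_24 - wtype.bias          # [[0, 1, -1, 2]], symmetric around zero
q_restored = q_centered + wtype.bias    # back to the stored (biased) encoding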
@ -427,7 +427,7 @@ def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int,
def marlin_24_quantize(
w: torch.Tensor,
num_bits: int,
quant_type: ScalarType,
group_size: int,
):
size_k, size_n = w.shape
@ -441,20 +441,18 @@ def marlin_24_quantize(
w_24, mask_24 = inject_24(w, size_k, size_n)
# Quantize
w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24,
num_bits,
group_size,
act_order=False)
w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights(
w_24, quant_type, group_size, act_order=False)
# Compress quantized weight
q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,
num_bits)
quant_type)
size_k_comp = size_k // 2
# Reformat to marlin
weight_perm = get_weight_perm_24(num_bits)
weight_perm = get_weight_perm_24(quant_type.size_bits)
marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,
num_bits, weight_perm)
quant_type.size_bits, weight_perm)
marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)
# Create result

View File

@ -4,7 +4,11 @@ from typing import List
import numpy
import torch
SUPPORTED_NUM_BITS = [4, 8]
from vllm.model_executor.layers.quantization.qqq import (
MARLIN_QQQ_SUPPORTED_NUM_BITS)
from vllm.scalar_type import ScalarType, scalar_types
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# Note: this is a hack. We should update each model to register the
@ -45,7 +49,7 @@ def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
def get_pack_factor(num_bits):
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
return 32 // num_bits
@ -74,24 +78,23 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
)
def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
act_order: bool):
def quantize_weights(w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
zero_points: bool = False):
assert quant_type.is_integer(), \
"Floating point quantization may work but has not been tested"
orig_device = w.device
orig_type = w.dtype
size_k, size_n = w.shape
assert w.is_floating_point(), "w must be float"
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
assert group_size in SUPPORTED_GROUP_SIZES + [
size_k
], f"Unsupported groupsize = {group_size}"
if group_size == -1:
group_size = size_k
assert group_size <= size_k
max_q_val = 2**num_bits - 1
half_q_val = (max_q_val + 1) // 2
# Reshape to [groupsize, -1]
if group_size < size_k:
w = w.reshape((-1, group_size, size_n))
@ -99,16 +102,34 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
w = w.reshape((group_size, -1))
# Compute scale for each group
s = torch.max(torch.abs(w), 0, keepdim=True)[0]
s *= 2 / max_q_val # 2 => symmetric
max_val = torch.max(w, 0, keepdim=True).values
min_val = torch.min(w, 0, keepdim=True).values
max_q_val = quant_type.max()
min_q_val = quant_type.min()
if zero_points:
assert not quant_type.is_signed() and quant_type.max() > 0
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
.clamp(min_q_val, max_q_val).int()
else:
# If the bias is such that there are no possible negative/positive
# values, set the max value to inf to avoid divide by 0
w_s = torch.max(
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
maybe_w_zp = None
# Quantize
q_w = torch.round(w / s).int()
q_w += half_q_val
q_w = torch.clamp(q_w, 0, max_q_val)
w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
w_q = torch.clamp(w_q, min_q_val, max_q_val)
# Compute ref (dequantized)
w_ref = (q_w - half_q_val).half() * s
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
if quant_type.has_bias():
w_q += quant_type.bias
# Restore original shapes
if group_size < size_k:
@ -119,10 +140,35 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
w = w.reshape((size_k, size_n)).contiguous()
return w
q_w = reshape_w(q_w)
w_q = reshape_w(w_q)
w_ref = reshape_w(w_ref)
s = s.reshape((-1, size_n)).contiguous()
w_s = w_s.reshape((-1, size_n)).contiguous()
if zero_points:
maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
maybe_w_zp = maybe_w_zp.to(device=orig_device)
return (
w_ref.to(device=orig_device),
w_q.to(device=orig_device),
w_s.to(device=orig_device),
maybe_w_zp,
)
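
A hedged usage sketch of the merged quantize_weights above; uint4 is the AWQ-style type (runtime zero-point), uint4b8 the GPTQ-style biased type:

import torch
from vllm.scalar_type import scalar_types
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    quantize_weights)

w = torch.randn(256, 64, dtype=torch.half)

# AWQ-style: unsigned type, per-group zero-points are returned
w_ref, w_q, w_s, w_zp = quantize_weights(
    w, scalar_types.uint4, group_size=64, zero_points=True)

# GPTQ-style: biased type, no zero-points (the last element comes back as None)
w_ref, w_q, w_s, none_zp = quantize_weights(w, scalar_types.uint4b8, group_size=64)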
def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
group_size: int, act_order: bool):
size_k, _ = w.shape
assert w.is_floating_point(), "w must be float"
assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \
f"Unsupported gptq type = {quant_type}"
assert group_size in SUPPORTED_GROUP_SIZES + [
size_k
], f"Unsupported groupsize = {group_size}"
w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
# Apply act_order
g_idx = torch.empty(0, dtype=torch.int, device=w.device)
@ -133,76 +179,9 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
), "For act_order, groupsize = {} must be less than size_k = {}".format(
group_size, size_k)
w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size)
w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size)
return (
w_ref.to(device=orig_device),
q_w.to(device=orig_device),
s.to(device=orig_device),
g_idx.to(device=orig_device),
rand_perm.to(device=orig_device),
)
def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
orig_device = w.device
size_k, size_n = w.shape
assert w.is_floating_point(), "w must be float"
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
assert group_size in SUPPORTED_GROUP_SIZES + [
size_k
], f"Unsupported groupsize = {group_size}"
if group_size == -1:
group_size = size_k
assert group_size <= size_k
max_q_val = 2**num_bits - 1
min_q_val = 0
# Reshape to [groupsize, -1]
if group_size < size_k:
w = w.reshape((-1, group_size, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((group_size, -1))
# Compute scale for each group
max = torch.max(w, 0, keepdim=True)[0]
min = torch.min(w, 0, keepdim=True)[0]
s = (max - min).clamp(min=1e-5) / max_q_val
# Compute zero-point for each group
zp = (-torch.round(min / s)).clamp(min_q_val, max_q_val).int()
# Quantize
q_w = torch.round(w / s).int() + zp
q_w = torch.clamp(q_w, min_q_val, max_q_val)
# Compute ref (dequantized)
w_ref = (q_w - zp).half() * s
# Restore original shapes
if group_size < size_k:
def reshape_w(w):
w = w.reshape((group_size, -1, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((size_k, size_n)).contiguous()
return w
q_w = reshape_w(q_w)
w_ref = reshape_w(w_ref)
s = s.reshape((-1, size_n)).contiguous()
zp = zp.reshape((-1, size_n)).contiguous()
return (
w_ref.to(device=orig_device),
q_w.to(device=orig_device),
s.to(device=orig_device),
zp.to(device=orig_device),
)
return w_ref, w_q, w_s, g_idx, rand_perm
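
And a matching sketch for the gptq_quantize_weights wrapper, whose extra act_order flag triggers the row permutation handled by permute_rows:

import torch
from vllm.scalar_type import scalar_types
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    gptq_quantize_weights)

w = torch.randn(256, 64, dtype=torch.half)
w_ref, w_q, w_s, g_idx, rand_perm = gptq_quantize_weights(
    w, scalar_types.uint4b8, group_size=64, act_order=True)
# With act_order=True the rows of w_q are shuffled; g_idx records each row's
# group index so the kernel can undo the permutation at runtime.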
# QQQ employs different quant schemes for per-group and
@ -212,7 +191,8 @@ def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
size_k, size_n = w.shape
assert w.is_floating_point(), "w must be float"
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \
f"Unsupported num_bits = {num_bits}"
assert group_size in SUPPORTED_GROUP_SIZES + [
size_k
], f"Unsupported groupsize = {group_size}"

vllm/scalar_type.py Normal file
View File

@ -0,0 +1,35 @@
from ._core_ext import NanRepr, ScalarType
# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
# flags:
# - no-flags: means it follows IEEE 754 conventions
# - f: means finite values only (no infinities)
# - n: means nans are supported (non-standard encoding)
# for integer types the scheme is:
# `[u]int<size_bits>[b<bias>]`
# - if bias is not present it means it is zero
class scalar_types:
int4 = ScalarType.int_(4, None)
uint4 = ScalarType.uint(4, None)
int8 = ScalarType.int_(8, None)
uint8 = ScalarType.uint(8, None)
float8_e4m3fn = ScalarType.float_(4, 3, True,
NanRepr.EXTD_RANGE_MAX_MIN.value)
float8_e5m2 = ScalarType.float_IEEE754(5, 2)
float16_e8m7 = ScalarType.float_IEEE754(8, 7)
float16_e5m10 = ScalarType.float_IEEE754(5, 10)
# fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value)
# "gptq" types
uint4b8 = ScalarType.uint(4, 8)
uint8b128 = ScalarType.uint(8, 128)
# colloquial names
bfloat16 = float16_e8m7
float16 = float16_e5m10
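
To make the naming scheme concrete, a short hedged sketch of the properties these types expose (attribute and method names as used elsewhere in this commit):

from vllm.scalar_type import scalar_types

t = scalar_types.uint4b8     # 4-bit unsigned integer stored with a bias of 8
print(t.size_bits)           # 4
print(t.bias)                # 8
print(t.min(), t.max())      # -8 7: the representable range once the bias is removed
print(32 // t.size_bits)     # 8 quantized values packed per int32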