[Misc] Disambiguate quantized types via a new ScalarType (#6396)

parent: b482b9a5b1
commit: a8d604ca2a
@@ -66,6 +66,39 @@ endif()
#
find_package(Torch REQUIRED)

#
# Add the `default` target which detects which extensions should be
# built based on platform/architecture. This is the same logic that
# setup.py uses to select which extensions should be built and should
# be kept in sync.
#
# The `default` target makes direct use of cmake easier since knowledge
# of which extensions are supported has been factored in, e.g.
#
# mkdir build && cd build
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
# cmake --build . --target default
#
add_custom_target(default)
message(STATUS "Enabling core extension.")

# Define _core_C extension
# built for (almost) every target platform, (excludes TPU and Neuron)

set(VLLM_EXT_SRC
  "csrc/core/torch_bindings.cpp")

define_gpu_extension_target(
  _core_C
  DESTINATION vllm
  LANGUAGE CXX
  SOURCES ${VLLM_EXT_SRC}
  COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
  USE_SABI 3
  WITH_SOABI)

add_dependencies(default _core_C)

#
# Forward the non-CUDA device extensions to external CMake scripts.
#
@@ -74,7 +107,7 @@ if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
if (VLLM_TARGET_DEVICE STREQUAL "cpu")
  include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
else()
  message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}")
  return()
endif()
return()
endif()
@@ -132,7 +165,7 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
endif()

#
# Define extension targets
# Define other extension targets
#

#
@@ -228,21 +261,6 @@ define_gpu_extension_target(


#
# Add the `default` target which detects which extensions should be
# built based on platform/architecture. This is the same logic that
# setup.py uses to select which extensions should be built and should
# be kept in sync.
#
# The `default` target makes direct use of cmake easier since knowledge
# of which extensions are supported has been factored in, e.g.
#
# mkdir build && cd build
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
# cmake --build . --target default
#
add_custom_target(default)

if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling C extension.")
add_dependencies(default _C)
@@ -13,6 +13,9 @@ COPY requirements-common.txt /workspace/vllm/
COPY requirements-openvino.txt /workspace/vllm/

COPY vllm/ /workspace/vllm/vllm
COPY csrc/core /workspace/vllm/csrc/core
COPY cmake/utils.cmake /workspace/vllm/cmake/
COPY CMakeLists.txt /workspace/vllm/
COPY setup.py /workspace/vllm/

# install build requirements
@@ -7,16 +7,17 @@ from benchmark_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
    MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS)
    MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    marlin_24_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    gptq_pack, quantize_weights, sort_weights)
    gptq_pack, gptq_quantize_weights, sort_weights)
from vllm.scalar_type import ScalarType
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
@@ -27,13 +28,14 @@ K_FULL_OPTS = [False, True]


def bench_run(results: List[benchmark.Measurement], model: str,
              act_order: bool, is_k_full: bool, num_bits: int, group_size: int,
              size_m: int, size_k: int, size_n: int):
              act_order: bool, is_k_full: bool, quant_type: ScalarType,
              group_size: int, size_m: int, size_k: int, size_n: int):
    label = "Quant Matmul"

    sub_label = ("{}, act={} k_full={}, b={}, g={}, "
                 "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits,
                                         group_size, size_m, size_k, size_n))
    sub_label = ("{}, act={} k_full={}, q={}, g={}, "
                 "MKN=({}x{}x{})".format(model, act_order, is_k_full,
                                         str(quant_type), group_size, size_m,
                                         size_k, size_n))

    print(f"Testing: {sub_label}")

@@ -50,18 +52,18 @@ def bench_run(results: List[benchmark.Measurement], model: str,
        marlin_g_idx,
        marlin_sort_indices,
        marlin_rand_perm,
    ) = marlin_quantize(b, num_bits, group_size, act_order)
    ) = marlin_quantize(b, quant_type, group_size, act_order)

    # Marlin_24 quant
    (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
     marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)
     marlin_24_s) = marlin_24_quantize(b, quant_type, group_size)

    marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)

    # GPTQ quant
    (w_ref, q_w, s, g_idx,
     rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
    q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)
     rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order)
    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx"
    # so that group ids are increasing
@@ -75,10 +77,11 @@ def bench_run(results: List[benchmark.Measurement], model: str,

    marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
                                          GPTQ_MARLIN_24_MAX_PARALLEL)
    marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)

    globals = {
        # Gen params
        "num_bits": num_bits,
        "quant_type": quant_type,
        "group_size": group_size,
        "size_m": size_m,
        "size_n": size_n,
@@ -128,7 +131,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
    results.append(
        benchmark.Timer(
            stmt=
            "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, False)",  # noqa: E501
            "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
@@ -138,19 +141,19 @@ def bench_run(results: List[benchmark.Measurement], model: str,
    results.append(
        benchmark.Timer(
            stmt=
            "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, True)",  # noqa: E501
            "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="gptq_marlin_gemm_fp32",
        ).blocked_autorange(min_run_time=min_run_time))

    if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
    if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
            and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
        results.append(
            benchmark.Timer(
                stmt=
                "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)",  # noqa: E501
                "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)",  # noqa: E501
                globals=globals,
                label=label,
                sub_label=sub_label,
@@ -160,7 +163,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
    results.append(
        benchmark.Timer(
            stmt=
            "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)",  # noqa: E501
            "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
@@ -196,9 +199,10 @@ def main(args):
                ) > 0 and is_k_full not in args.limit_k_full:
                    continue

                for num_bits in MARLIN_SUPPORTED_NUM_BITS:
                    if len(args.limit_num_bits
                           ) > 0 and num_bits not in args.limit_num_bits:
                for quant_type in query_marlin_supported_quant_types(
                        False):
                    if len(args.limit_num_bits) > 0 and \
                            quant_type.size_bits not in args.limit_num_bits:
                        continue

                    for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
@@ -215,8 +219,8 @@ def main(args):

                        for size_m in args.batch_sizes:
                            bench_run(results, model, act_order, is_k_full,
                                      num_bits, group_size, size_m, size_k,
                                      size_n)
                                      quant_type, group_size, size_m,
                                      size_k, size_n)

    compare = benchmark.Compare(results)
    compare.print()
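The benchmark loop above now enumerates ScalarType objects rather than raw bit widths. A minimal sketch of that iteration (not part of the diff), assuming the boolean argument to query_marlin_supported_quant_types selects zero-point support, as the tests later in this commit suggest:

from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    query_marlin_supported_quant_types)

# Enumerate the ScalarTypes the Marlin kernels accept without runtime zero
# points; the old --limit-num-bits filter now compares against size_bits.
for quant_type in query_marlin_supported_quant_types(False):
    print(quant_type, quant_type.size_bits)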
@@ -113,6 +113,5 @@ define_gpu_extension_target(
WITH_SOABI
)

add_custom_target(default)
message(STATUS "Enabling C extension.")
add_dependencies(default _C)
csrc/core/scalar_type.hpp (new file, 382 lines)
@@ -0,0 +1,382 @@
#pragma once

#include <torch/custom_class.h>

namespace vllm {

//
// ScalarType can represent a wide range of floating point and integer types,
// in particular it can be used to represent sub-byte data types (something
// that torch.dtype currently does not support).
//
// ScalarTypeTorch is a subclass of ScalarType that is compatible with
// TORCH_LIBRARY, making it accessible from Python as well meaning this class
// can be used as a argument for custom operators, helping to simplify these
// interfaces.
//
// The type definitions on the Python side can be found in: vllm/_core_ext.pyi
// these type definitions should be kept up to date with any Python API changes
// here.
//
class ScalarType {
 public:
  enum NanRepr : int64_t {
    NAN_NONE = 0,                // nans are not supported
    NAN_IEEE_754 = 1,            // nans are: exp all 1s, mantissa not all 0s
    NAN_EXTD_RANGE_MAX_MIN = 2,  // nans are: exp all 1s, mantissa all 1s

    NAN_REPR_ID_MAX
  };

  constexpr ScalarType(bool signed_, int64_t exponent, int64_t mantissa,
                       int64_t bias, bool finite_values_only = false,
                       NanRepr nan_repr = NAN_IEEE_754)
      : exponent(exponent),
        mantissa(mantissa),
        bias(bias),
        signed_(signed_),
        finite_values_only(finite_values_only),
        nan_repr(nan_repr){};

  static constexpr ScalarType int_(int64_t size_bits, int64_t bias = 0) {
    return ScalarType(true, 0, size_bits - 1, bias);
  }

  static constexpr ScalarType uint(int64_t size_bits, int64_t bias = 0) {
    return ScalarType(false, 0, size_bits, bias);
  }

  // IEEE 754 compliant floating point type
  static constexpr ScalarType float_IEEE754(int64_t exponent,
                                            int64_t mantissa) {
    TORCH_CHECK(mantissa > 0 && exponent > 0);
    return ScalarType(true, exponent, mantissa, 0, false, NAN_IEEE_754);
  }

  // IEEE 754 non-compliant floating point type
  static constexpr ScalarType float_(int64_t exponent, int64_t mantissa,
                                     bool finite_values_only,
                                     NanRepr nan_repr) {
    TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
    TORCH_CHECK(mantissa > 0 && exponent > 0);
    TORCH_CHECK(nan_repr != NAN_IEEE_754,
                "use `float_IEEE754` constructor for floating point types that "
                "follow IEEE 754 conventions");
    return ScalarType(true, exponent, mantissa, 0, finite_values_only,
                      nan_repr);
  }

  int64_t const exponent;  // size of the exponent field (0 for integer types)
  int64_t const mantissa;  // size of the mantissa field (size of the integer
                           // excluding the sign bit for integer types)
  int64_t const bias;      // stored values equal value + bias,
                           // used for quantized type
  bool const signed_;  // flag if the type supports negative numbers (i.e. has a
                       // sign bit)

  // Extra Floating point info
  bool const finite_values_only;  // i.e. no +/-inf if true
  NanRepr const nan_repr;         // how NaNs are represented
                                  // (not applicable for integer types)

  int64_t size_bits() const { return mantissa + exponent + is_signed(); }
  bool is_signed() const { return signed_; }
  bool is_integer() const { return exponent == 0; }
  bool is_floating_point() const { return exponent > 0; }
  bool is_ieee_754() const {
    return is_floating_point() && finite_values_only == false &&
           nan_repr == NAN_IEEE_754;
  }
  bool has_nans() const { return is_floating_point() && nan_repr != NAN_NONE; }
  bool has_infs() const {
    return is_floating_point() && finite_values_only == false;
  }
  bool has_bias() const { return bias != 0; }

 private:
  double _floating_point_max() const {
    TORCH_CHECK(mantissa <= 52 && exponent <= 11,
                "Cannot represent max/min as a double for type ", str());

    uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
    if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
      max_mantissa -= 1;
    }

    uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
    if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
      TORCH_CHECK(exponent < 11,
                  "Cannot represent max/min as a double for type ", str());
      max_exponent += 1;
    }

    // adjust the exponent to match that of a double
    // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e
    // is the exponent bits), there is some precedent for non-standard biases,
    // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes
    // but to avoid premature over complication we are just assuming the
    // standard exponent bias until there is a need to support non-standard
    // biases
    uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
    uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1;  // double e = 11

    uint64_t max_exponent_double =
        max_exponent - exponent_bias + exponent_bias_double;

    // shift the mantissa into the position for a double and
    // the exponent
    uint64_t double_raw =
        (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);

    return *reinterpret_cast<double*>(&double_raw);
  }

  std::variant<int64_t, double> _raw_max() const {
    if (is_floating_point()) {
      return {_floating_point_max()};
    } else {
      TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(),
                  "Cannot represent max as a int64_t");
      return {(int64_t(1) << mantissa) - 1};
    }
  }

  std::variant<int64_t, double> _raw_min() const {
    if (is_floating_point()) {
      TORCH_CHECK(is_signed(),
                  "We currently assume all floating point types are signed");
      constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);

      double max = _floating_point_max();
      uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
      uint64_t min_raw = max_raw | sign_bit_double;
      return {*reinterpret_cast<double*>(&min_raw)};
    } else {
      TORCH_CHECK(!is_signed() || size_bits() <= 64,
                  "Cannot represent min as a int64_t");
      if (is_signed()) {
        // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0
        // then perform an arithmetic shift right to set all the bits above
        // (size_bits() - 1) to 1
        return {INT64_MIN >> (64 - size_bits())};
      } else {
        return {int64_t(0)};
      }
    }
  }

 public:
  // Max representable value for this scalar type.
  // (accounting for bias if there is one)
  std::variant<int64_t, double> max() const {
    return std::visit(
        [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
        _raw_max());
  }

  // Min representable value for this scalar type.
  // (accounting for bias if there is one)
  std::variant<int64_t, double> min() const {
    return std::visit(
        [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
        _raw_min());
  }

  std::string str() const {
    /* naming generally follows: https://github.com/jax-ml/ml_dtypes
     * for floating point types (leading f) the scheme is:
     *  `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
     *  flags:
     *  - no-flags: means it follows IEEE 754 conventions
     *  - f: means finite values only (no infinities)
     *  - n: means nans are supported (non-standard encoding)
     * for integer types the scheme is:
     *  `[u]int<size_bits>[b<bias>]`
     *  - if bias is not present it means its zero
     */
    if (is_floating_point()) {
      auto ret = "float" + std::to_string(size_bits()) + "_e" +
                 std::to_string(exponent) + "m" + std::to_string(mantissa);
      if (!is_ieee_754()) {
        if (finite_values_only) {
          ret += "f";
        }
        if (nan_repr != NAN_NONE) {
          ret += "n";
        }
      }
      return ret;
    } else {
      auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
      if (has_bias()) {
        ret += "b" + std::to_string(bias);
      }
      return ret;
    }
  }

  bool operator==(ScalarType const& other) const {
    return mantissa == other.mantissa && exponent == other.exponent &&
           bias == other.bias && signed_ == other.signed_ &&
           finite_values_only == other.finite_values_only &&
           nan_repr == other.nan_repr;
  }
};

// Create a TORCH_LIBRARY compatible version of ScalarType (i.e. inherit from
// torch::CustomClassHolder), we use multiple inheritance here since we cannot
// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
// constructor at the same time (torch::CustomClassHolder does not have a
// constexpr destructor)
class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
 public:
  ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
                  bool _signed)
      : ScalarType(exponent, mantissa, bias, _signed){};

  ScalarTypeTorch(ScalarType type) : ScalarType(type){};

  using Base = ScalarType;
  using Self = ScalarTypeTorch;
  using SelfPtr = c10::intrusive_ptr<Self>;

  static SelfPtr int_(int64_t size_bits, c10::optional<int64_t> bias) {
    return c10::make_intrusive<Self>(
        ScalarType::int_(size_bits, bias.value_or(0)));
  }

  static SelfPtr uint(int64_t size_bits, c10::optional<int64_t> bias) {
    return c10::make_intrusive<Self>(
        ScalarType::uint(size_bits, bias.value_or(0)));
  }

  static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) {
    return c10::make_intrusive<Self>(
        ScalarType::float_IEEE754(exponent, mantissa));
  }

  static SelfPtr float_(int64_t exponent, int64_t mantissa,
                        bool finite_values_only, int64_t nan_repr) {
    return c10::make_intrusive<Self>(ScalarType::float_(
        exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
  }

  template <typename T>
  static void bind_readonly_property(torch::class_<Self>& cls,
                                     std::string const& name, T Base::*field) {
    auto getter_func = [field = std::move(field)](SelfPtr const& self) {
      if constexpr (std::is_member_function_pointer_v<decltype(field)>) {
        return (self.get()->*field)();
      } else {
        return self.get()->*field;
      }
    };

    cls.def_property(name, getter_func);
  }

  template <typename MemberFunc, typename Cls>
  static void bind_function(torch::class_<Self>& cls, const std::string& name,
                            MemberFunc Cls::*member) {
    cls.def(name, [member = std::move(member)](SelfPtr const& self) {
      return (self.get()->*member)();
    });
  }

  template <typename Func>
  static void bind_function(torch::class_<Self>& cls, const std::string& name,
                            Func func) {
    cls.def(name, func);
  }

  template <typename Func>
  static void bind_static_function(torch::class_<Self>& cls,
                                   const std::string& name, Func func) {
    cls.def_static(name, func);
  }

  static void bind_class(torch::Library& lib) {
    auto cls = lib.class_<ScalarTypeTorch>("ScalarType")
                   .def(torch::init<int64_t, int64_t, int64_t, bool>());

    // Bind Properties
    bind_readonly_property(cls, "mantissa", &Base::mantissa);
    bind_readonly_property(cls, "exponent", &Base::exponent);
    bind_readonly_property(cls, "bias", &Base::bias);
    bind_readonly_property(cls, "signed", &Base::is_signed);
    bind_readonly_property(cls, "size_bits", &Base::size_bits);

    // Bind member functions
    bind_function(cls, "is_signed", &Base::is_signed);
    bind_function(cls, "is_integer", &Base::is_integer);
    bind_function(cls, "is_floating_point", &Base::is_floating_point);
    bind_function(cls, "is_ieee_754", &Base::is_ieee_754);
    bind_function(cls, "has_nans", &Base::has_nans);
    bind_function(cls, "has_infs", &Base::has_infs);
    bind_function(cls, "has_bias", &Base::has_bias);

    bind_function(cls, "max", [](SelfPtr const& self) {
      return std::visit([](auto arg) { return c10::IValue(arg); },
                        self.get()->max());
    });
    bind_function(cls, "min", [](SelfPtr const& self) {
      return std::visit([](auto arg) { return c10::IValue(arg); },
                        self.get()->min());
    });

    bind_function(cls, "__str__", &Base::str);
    bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
      return *self == *other;
    });
    bind_function(cls, "__repr__", [](SelfPtr const& self) {
      return "ScalarType." + self.get()->str();
    });

    // Bind static functions (convenience constructors)
    bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
    bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
    bind_static_function(cls, "float_IEEE754", &ScalarTypeTorch::float_IEEE754);
    bind_static_function(cls, "float_", &ScalarTypeTorch::float_);
  }
};

using ScalarTypeTorchPtr = c10::intrusive_ptr<ScalarTypeTorch>;

// "rust style" names generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
static inline constexpr auto kS4 = ScalarType::int_(4);
static inline constexpr auto kU4 = ScalarType::uint(4);
static inline constexpr auto kU4B8 = ScalarType::uint(4, 8);
static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);

static inline constexpr auto kFE3M2f =
    ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn =
    ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);

// Fixed width style names, generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
static inline constexpr auto kInt4 = kS4;
static inline constexpr auto kUint4 = kU4;
static inline constexpr auto kUint4b8 = kU4B8;
static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128;

static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
static inline constexpr auto kFloat16_e8m7 = kFE8M7;
static inline constexpr auto kFloat16_e5m10 = kFE5M10;

// colloquial names
static inline constexpr auto kHalf = kFE5M10;
static inline constexpr auto kFloat16 = kHalf;
static inline constexpr auto kBFloat16 = kFE8M7;

};  // namespace vllm
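The presets at the bottom of this header are mirrored on the Python side as `scalar_types`. A small sketch (not part of the diff) of reading the same information back from Python, using names that appear in the test and benchmark files of this PR:

from vllm.scalar_type import scalar_types

t = scalar_types.uint4b8                 # GPTQ-style 4-bit: stored = value + 8
print(t, t.size_bits, t.min(), t.max())  # uint4b8 4 -8 7

f = scalar_types.float8_e4m3fn           # finite-only float, non-IEEE NaN encoding ("fn" suffix)
print(f.is_floating_point(), f.has_infs(), f.has_nans())  # True False True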
csrc/core/torch_bindings.cpp (new file, 16 lines)
@@ -0,0 +1,16 @@
#include <torch/library.h>

#include "scalar_type.hpp"
#include "registration.h"

// Note the CORE exstension will be built for (almost) all hardware targets so
// new additions must account for this. (currently not built for TPU and Neuron)

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, lib) {
  // ScalarType, a custom class for representing data types that supports
  // quantized types, declared here so it can be used when creating interfaces
  // for custom ops.
  vllm::ScalarTypeTorch::bind_class(lib);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
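Once this extension is loaded, the bound class is reachable as a torch custom class; this is how `vllm/_core_ext.py` (added later in this diff) picks it up:

import torch
import vllm._core_C  # noqa: F401  # loads the TORCH_LIBRARY registration above

ScalarType = torch.classes._core_C.ScalarType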
@@ -1,6 +1,6 @@
#include "cache.h"
#include "ops.h"
#include "registration.h"
#include "core/registration.h"

#include <torch/library.h>

@@ -1,4 +1,4 @@
#include "registration.h"
#include "core/registration.h"
#include "moe_ops.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
@@ -3,6 +3,8 @@
#include <optional>
#include <torch/library.h>

#include "core/scalar_type.hpp"

void paged_attention_v1(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
@@ -84,14 +86,16 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  torch::Tensor& b_meta,
                                  torch::Tensor& b_scales,
                                  torch::Tensor& workspace, int64_t num_bits,
                                  torch::Tensor& workspace,
                                  vllm::ScalarTypeTorchPtr const& b_q_type,
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k);

torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
                               torch::Tensor& g_idx, torch::Tensor& perm,
                               torch::Tensor& workspace, int64_t num_bits,
                               torch::Tensor& workspace,
                               vllm::ScalarTypeTorchPtr const& b_q_type,
                               int64_t size_m, int64_t size_n, int64_t size_k,
                               bool is_k_full, bool has_zp,
                               bool use_fp32_reduce);
@@ -21,6 +21,7 @@

#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"

#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
  static_assert(std::is_same<scalar_t, half>::value || \
@@ -71,14 +72,15 @@ __global__ void Marlin(
    bool use_fp32_reduce // whether to use fp32 global reduce
) {}

} // namespace gptq_marlin
} // namespace marlin

torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
                               torch::Tensor& g_idx, torch::Tensor& perm,
                               torch::Tensor& workspace, int64_t num_bits,
                               torch::Tensor& workspace,
                               vllm::ScalarTypeTorchPtr const& b_q_type,
                               int64_t size_m, int64_t size_n, int64_t size_k,
                               bool is_k_full) {
                               bool is_k_full, bool has_zp) {
  TORCH_CHECK_NOT_IMPLEMENTED(false,
                              "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
  return torch::empty({1, 1});
@@ -1963,18 +1965,29 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)

template <typename scalar_t>
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
                     void* s, void* zp, void* g_idx, void* perm, void* a_tmp,
                     int prob_m, int prob_n, int prob_k, void* workspace,
                     int num_bits, bool has_act_order, bool is_k_full,
                     bool has_zp, int num_groups, int group_size, int dev,
                     cudaStream_t stream, int thread_k, int thread_n, int sms,
                     int max_par, bool use_fp32_reduce) {
  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
               void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m,
               int prob_n, int prob_k, void* workspace,
               vllm::ScalarType const& q_type, bool has_act_order,
               bool is_k_full, bool has_zp, int num_groups, int group_size,
               int dev, cudaStream_t stream, int thread_k, int thread_n,
               int sms, int max_par, bool use_fp32_reduce) {
  if (has_zp) {
    TORCH_CHECK(
        q_type == vllm::kU4 || q_type == vllm::kU8,
        "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
  } else {
    TORCH_CHECK(
        q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
        "q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
        q_type.str());
  }

  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
              ", ", prob_n, ", ", prob_k, "]");

  // TODO: remove alias when we start supporting other 8bit types
  int num_bits = q_type.size_bits();
  int tot_m = prob_m;
  int tot_m_blocks = div_ceil(tot_m, 16);
  int pad = 16 * tot_m_blocks - tot_m;
@@ -2126,19 +2139,28 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
  }
}

} // namespace gptq_marlin
} // namespace marlin

torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
                               torch::Tensor& g_idx, torch::Tensor& perm,
                               torch::Tensor& workspace, int64_t num_bits,
                               torch::Tensor& workspace,
                               vllm::ScalarTypeTorchPtr const& b_q_type,
                               int64_t size_m, int64_t size_n, int64_t size_k,
                               bool is_k_full, bool has_zp,
                               bool use_fp32_reduce) {
  // Verify num_bits
  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int pack_factor = 32 / num_bits;
  if (has_zp) {
    TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8,
                "b_q_type must be u4 or u8 when has_zp = True. Got = ",
                b_q_type->str());
  } else {
    TORCH_CHECK(
        *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
        "b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
        b_q_type->str());
  }

  int pack_factor = 32 / b_q_type->size_bits();

  // Verify A
  TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
@@ -2265,21 +2287,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,

  int dev = a.get_device();
  if (a.scalar_type() == at::ScalarType::Half) {
    marlin::marlin_mm_f16i4<half>(
    marlin::marlin_mm<half>(
        a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
        c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
        b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
        a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
        workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
        workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
        num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
  } else if (a.scalar_type() == at::ScalarType::BFloat16) {
    marlin::marlin_mm_f16i4<nv_bfloat16>(
    marlin::marlin_mm<nv_bfloat16>(
        a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
        c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
        b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
        perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
        workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
        workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
        num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
  } else {
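The TORCH_CHECKs in `marlin_mm` encode which quantized types the kernel accepts; the same pairing viewed from Python (a sketch, using the presets this PR defines):

from vllm.scalar_type import scalar_types

# has_zp=True  -> plain unsigned types (runtime zero points)
# has_zp=False -> biased GPTQ-style types (implicit zero point via the bias)
ZP_TYPES = (scalar_types.uint4, scalar_types.uint8)
NO_ZP_TYPES = (scalar_types.uint4b8, scalar_types.uint8b128)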
@@ -27,6 +27,7 @@
#include <iostream>

#include "common/base.h"
#include "core/scalar_type.hpp"

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

@@ -86,7 +87,8 @@ __global__ void Marlin_24(
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  torch::Tensor& b_meta,
                                  torch::Tensor& b_scales,
                                  torch::Tensor& workspace, int64_t num_bits,
                                  torch::Tensor& workspace,
                                  vllm::ScalarTypeTorchPtr const& b_q_type,
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k) {
  TORCH_CHECK_NOT_IMPLEMENTED(
@@ -1025,13 +1027,14 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  torch::Tensor& b_meta,
                                  torch::Tensor& b_scales,
                                  torch::Tensor& workspace, int64_t num_bits,
                                  torch::Tensor& workspace,
                                  vllm::ScalarTypeTorchPtr const& b_q_type,
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k) {
  // Verify num_bits
  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int pack_factor = 32 / num_bits;
  TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
              "num_bits must be uint4b8 or uint8b128. Got = ", b_q_type->str());
  int pack_factor = 32 / b_q_type->size_bits();

  // Verify M
  TORCH_CHECK(size_m == a.size(0),
@@ -1126,8 +1129,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
  marlin_24::marlin_cuda_2_4(
      a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(),
      b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(),
      num_bits, groupsize, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
      thread_m, sms, max_par);
      b_q_type->size_bits(), groupsize, dev,
      at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par);

  return c;
}
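The 2:4 kernel now derives its packing factor from the type object instead of a raw bit count. A quick worked check of that arithmetic (sketch, not part of the diff):

from vllm.scalar_type import scalar_types

for t in (scalar_types.uint4b8, scalar_types.uint8b128):
    # pack_factor = 32 / size_bits: 4-bit packs 8 values per int32, 8-bit packs 4
    print(t, t.size_bits, 32 // t.size_bits)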
@@ -1,7 +1,7 @@
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include "registration.h"
#include "core/registration.h"

#include <torch/library.h>

setup.py (9 lines changed)
@@ -271,6 +271,10 @@ def _build_custom_ops() -> bool:
    return _is_cuda() or _is_hip() or _is_cpu()


def _build_core_ext() -> bool:
    return not _is_neuron() and not _is_tpu()


def get_hipcc_rocm_version():
    # Run the hipcc --version command
    result = subprocess.run(['hipcc', '--version'],
@@ -433,6 +437,9 @@ def get_requirements() -> List[str]:

ext_modules = []

if _build_core_ext():
    ext_modules.append(CMakeExtension(name="vllm._core_C"))

if _is_cuda() or _is_hip():
    ext_modules.append(CMakeExtension(name="vllm._moe_C"))

@@ -477,7 +484,7 @@ setup(
    extras_require={
        "tensorizer": ["tensorizer>=2.9.0"],
    },
    cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {},
    cmdclass={"build_ext": cmake_build_ext} if len(ext_modules) > 0 else {},
    package_data=package_data,
    entry_points={
        "console_scripts": [
|
@ -1,8 +1,6 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# ruff: noqa: F401
|
||||
import vllm._C
|
||||
from tests.kernels.quant_utils import ref_dynamic_per_token_quant
|
||||
from vllm._custom_ops import scaled_int8_quant
|
||||
|
||||
|
@@ -9,14 +9,14 @@ from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.qqq import (
    MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N,
    MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
    MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS,
    marlin_make_empty_g_idx, marlin_permute_scales)
    MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
    marlin_permute_scales, query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    pack_fp8_to_int32)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
@@ -27,8 +27,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import (  # noqa: E501
    marlin_qqq_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    awq_pack, gptq_pack, quantize_weights, quantize_weights_with_zp,
    sort_weights)
    awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
@@ -65,12 +64,13 @@ def rand_data(shape, dtype=torch.float16):
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
                            mnk_factors):
def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
                            act_order, mnk_factors):
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
@@ -95,11 +95,11 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
    b_weight = rand_data((size_k, size_n))

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(b_weight, num_bits,
                                                       group_size, act_order)
    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
        b_weight, quant_type, group_size, act_order)

    # Pack to GPTQ format
    q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)
    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
@@ -108,8 +108,9 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Pack to Marlin format
    weight_perm = get_weight_perm(num_bits)
    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
                                  weight_perm)

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.gptq_marlin_repack(
@@ -117,7 +118,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
        sort_indices,
        size_k,
        size_n,
        num_bits,
        quant_type.size_bits,
    )
    torch.cuda.synchronize()

@@ -128,10 +129,11 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
                           mnk_factors):
    m_factor, n_factor, k_factor = mnk_factors

@@ -150,22 +152,25 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
    b_weight = rand_data((size_k, size_n))

    # Quantize
    w_ref, q_w, s, zp = quantize_weights_with_zp(b_weight, num_bits,
                                                 group_size)
    w_ref, q_w, s, zp = quantize_weights(b_weight,
                                         quant_type,
                                         group_size,
                                         zero_points=True)

    # Pack to AWQ format
    q_w_awq = awq_pack(q_w, num_bits, size_k, size_n)
    q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # Pack to Marlin format
    weight_perm = get_weight_perm(num_bits)
    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
                                  weight_perm)

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.awq_marlin_repack(
        q_w_awq,
        size_k,
        size_n,
        num_bits,
        quant_type.size_bits,
    )
    torch.cuda.synchronize()

@@ -176,7 +181,8 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@@ -185,7 +191,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
def test_gptq_marlin_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    quant_type,
    group_size,
    mnk_factors,
    act_order,
@@ -211,7 +217,7 @@ def test_gptq_marlin_gemm(
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, num_bits, group_size, act_order)
        b_weight, quant_type, group_size, act_order)

    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)

@@ -226,7 +232,7 @@ def test_gptq_marlin_gemm(
        g_idx,
        sort_indices,
        workspace.scratch,
        num_bits,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
@@ -248,10 +254,10 @@ def test_gptq_marlin_gemm(
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size,
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
                             mnk_factors):
    m_factor, n_factor, k_factor = mnk_factors

@@ -266,7 +272,7 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size,
    b_weight = rand_data((size_k, size_n))

    (w_24_ref, marlin_24_q_w_comp, marlin_24_meta,
     marlin_24_s) = marlin_24_quantize(b_weight, num_bits, group_size)
     marlin_24_s) = marlin_24_quantize(b_weight, quant_type, group_size)

    workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
                                   GPTQ_MARLIN_24_MAX_PARALLEL)
@@ -279,7 +285,7 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size,
        marlin_24_meta,
        marlin_24_s,
        workspace_24.scratch,
        num_bits,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
@@ -371,14 +377,15 @@ def test_fp8_marlin_gemm(
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_awq_marlin_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    quant_type,
    group_size,
    mnk_factors,
    use_fp32_reduce,
@@ -396,7 +403,7 @@ def test_awq_marlin_gemm(
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
        b_weight, num_bits, group_size)
        b_weight, quant_type, group_size)

    g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
    sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
@@ -414,7 +421,7 @@ def test_awq_marlin_gemm(
        g_idx,
        sort_indices,
        workspace.scratch,
        num_bits,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
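As the AWQ repack test above shows, the separate quantize_weights_with_zp helper is folded into quantize_weights behind a zero_points flag. A condensed sketch of the new call (shapes, group size and device are illustrative, not from the diff):

import torch
from vllm.scalar_type import scalar_types
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    quantize_weights)

w = torch.randn((512, 256), dtype=torch.float16, device="cuda")
# Returns the dequantized reference, quantized weights, scales and zero points.
w_ref, q_w, s, zp = quantize_weights(w, scalar_types.uint4, 128, zero_points=True)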
tests/test_scalartype.py (new file, 36 lines)
@@ -0,0 +1,36 @@
import pytest
import torch

from vllm.scalar_type import scalar_types


@pytest.mark.parametrize("type_tuple", (
    (-8, 7, scalar_types.int4),
    (0, 15, scalar_types.uint4),
    (-8, 7, scalar_types.uint4b8),
    (-128, 127, scalar_types.uint8b128),
    (-28., 28., scalar_types.float6_e3m2f),
    (torch.int8, scalar_types.int8),
    (torch.uint8, scalar_types.uint8),
    (torch.float8_e5m2, scalar_types.float8_e5m2),
    (torch.float8_e4m3fn, scalar_types.float8_e4m3fn),
    (torch.bfloat16, scalar_types.float16_e8m7),
    (torch.float16, scalar_types.float16_e5m10),
),
                         ids=lambda x: str(x))
def test_scalar_type_min_max(type_tuple):
    print(type_tuple)
    if len(type_tuple) == 3:
        min, max, t = type_tuple
    else:
        torch_type, t = type_tuple
        if torch_type.is_floating_point:
            min = torch.finfo(torch_type).min
            max = torch.finfo(torch_type).max
        else:
            min = torch.iinfo(torch_type).min
            max = torch.iinfo(torch_type).max

    print(t, min, max, t.min(), t.max())
    assert min == t.min()
    assert max == t.max()
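The (-8, 7, scalar_types.uint4b8) case above is worth spelling out: the type stores four unsigned bits (0..15) with a bias of 8, so the representable values are stored - bias:

from vllm.scalar_type import scalar_types

t = scalar_types.uint4b8
assert t.size_bits == 4
# stored values 0..15, bias 8  ->  value range -8..7
assert (t.min(), t.max()) == (-8, 7)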
vllm/_core_ext.py (new file, 177 lines)
@@ -0,0 +1,177 @@
import importlib.util
from enum import Enum
from typing import TYPE_CHECKING, Optional, Union

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)
core_C_available = importlib.util.find_spec('._core_C', 'vllm') is not None


# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    NONE = 0  # nans are not supported
    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s


if TYPE_CHECKING or not core_C_available:
    # On platforms were we cannot use/build the C++ core extension (i.e. namely
    # neuron and tpu), we define the mock ScalarType class here that partially
    # mimics the C++ ScalarType class.
    #
    # We also use this provide type signatures to the Python LSP for the methods
    # in the C++ ScalarType class. So these type signatures should be kept
    # in sync with csrc/core/scalar_type.hpp

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ScalarType:
        """
        ScalarType can represent a wide range of floating point and integer
        types, in particular it can be used to represent sub-byte data types
        (something that torch.dtype currently does not support). It is also
        capable of representing types with a bias, i.e.:
          `stored_value = value + bias`,
        this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
        of 8). The implementation for this class can be found in
        csrc/core/scalar_type.hpp, these type signatures should be kept in sync
        with that file.
        """

        exponent: int
        """
        Number of bits in the exponent if this is a floating point type
        (zero if this an integer type)
        """

        mantissa: int
        """
        Number of bits in the mantissa if this is a floating point type,
        or the number bits representing an integer excluding the sign bit if
        this an integer type.
        """

        bias: int
        """
        bias used to encode the values in this scalar type
        (value = stored_value - bias, default 0) for example if we store the
        type as an unsigned integer with a bias of 128 then the value 0 will be
        stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
        """

        signed: bool
        "If the type is signed (i.e. has a sign bit)"

        _finite_values_only: bool = False
        """
        Private: if NANs are supported, used `has_infs()` instead.
        """

        nan_repr: int = NanRepr.IEEE_754.value
        """
        How NaNs are represent in this scalar type, returns NanRepr value.
        (not applicable for integer types)
        """

        @property
        def size_bits(self):
            return self.exponent + self.mantissa + int(self.signed)

        def min(self) -> Union[int, float]:
            """
            Min representable value for this scalar type.
            (accounting for bias if there is one)
            """
            raise NotImplementedError

        def max(self) -> Union[int, float]:
            """
            Max representable value for this scalar type.
            (accounting for bias if there is one)
            """
            raise NotImplementedError

        def is_signed(self) -> bool:
            """
            If the type is signed (i.e. has a sign bit), same as `signed`
            added for consistency with:
            https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
            """
            ...

        def is_floating_point(self):
            "If the type is a floating point type"
            return self.exponent != 0

        def is_integer(self):
            "If the type is an integer type"
            return self.exponent == 0

        def has_bias(self):
            "If the type has a non-zero bias"
            return self.bias != 0

        def has_infs(self):
            "If the type is floating point and supports infinity"
            return not self._finite_values_only

        def has_nans(self):
            return self.nan_repr != NanRepr.NONE.value

        def is_ieee_754(self) -> bool:
            """
            If the type is a floating point type that follows IEEE 754
            conventions
            """
            return self.nan_repr == NanRepr.IEEE_754.value and \
                not self._finite_values_only

        def __str__(self) -> str:
            raise NotImplementedError

        def __repr__(self) -> str:
            raise NotImplementedError

        #
        # Convenience Constructors
        #

        @classmethod
        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            "Create a signed integer scalar type (size_bits includes sign-bit)."
            return cls(size_bits - 1, size_bits, bias if bias else 0, True)

        @classmethod
        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            """Create a unsigned integer scalar type."""
            return cls(size_bits, size_bits, bias if bias else 0, False)

        @classmethod
        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
            """
            Create a standard floating point type
            (i.e. follows IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True)

        @classmethod
        def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
                   nan_repr: int):
            """
            Create a non-standard floating point type
            (i.e. does not follow IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True, finite_values_only,
                       nan_repr)

elif core_C_available:
    try:
        import vllm._core_C  # noqa: F401
    except ImportError as e:
        logger.warning("Failed to import from vllm._core_C with %r", e)

    ScalarType = torch.classes._core_C.ScalarType
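The bias semantics described in the docstring above can be checked directly against the presets this PR ships (a sketch; the expected values come from tests/test_scalartype.py):

from vllm.scalar_type import scalar_types

t = scalar_types.uint8b128          # 8 stored bits, bias 128 (GPTQ-style int8)
# value = stored - bias: stored 0..255 maps to -128..127
assert (t.min(), t.max()) == (-128, 127)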
@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Union

import torch

from vllm._core_ext import ScalarType
from vllm.logger import init_logger

logger = init_logger(__name__)
@ -220,10 +221,10 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,

# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_meta: torch.Tensor, b_scales: torch.Tensor,
                        workspace: torch.Tensor, num_bits: int, size_m: int,
                        size_n: int, size_k: int) -> torch.Tensor:
                        workspace: torch.Tensor, b_q_type: ScalarType,
                        size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
                                            workspace, num_bits, size_m,
                                            workspace, b_q_type, size_m,
                                            size_n, size_k)


@ -279,14 +280,22 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
    return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)


def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor, b_zeros: torch.Tensor,
                     g_idx: torch.Tensor, perm: torch.Tensor,
                     workspace: torch.Tensor, num_bits: int, size_m: int,
                     size_n: int, size_k: int, is_k_full: bool, has_zp: bool,
                     use_fp32_reduce: bool) -> torch.Tensor:
def gptq_marlin_gemm(a: torch.Tensor,
                     b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor,
                     b_zeros: torch.Tensor,
                     g_idx: torch.Tensor,
                     perm: torch.Tensor,
                     workspace: torch.Tensor,
                     b_q_type: ScalarType,
                     size_m: int,
                     size_n: int,
                     size_k: int,
                     is_k_full: bool,
                     has_zp: bool = False,
                     use_fp32_reduce: bool = False) -> torch.Tensor:
    return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
                                         g_idx, perm, workspace, num_bits,
                                         g_idx, perm, workspace, b_q_type,
                                         size_m, size_n, size_k, is_k_full,
                                         has_zp, use_fp32_reduce)

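Example (illustrative only, not part of this commit): with the new signatures above, callers hand the kernel a ScalarType instead of a bare bit width. The helper name below is hypothetical, and the tensors must already be in Marlin layout on a CUDA device when it is called.

import torch

from vllm import _custom_ops as ops
from vllm.scalar_type import scalar_types


def marlin_uint4b8_gemm(a: torch.Tensor, qweight: torch.Tensor,
                        scales: torch.Tensor, zeros: torch.Tensor,
                        g_idx: torch.Tensor, perm: torch.Tensor,
                        workspace: torch.Tensor, size_m: int, size_n: int,
                        size_k: int) -> torch.Tensor:
    # Previously this call passed num_bits=4; now the quantized weight type
    # itself (uint4b8) is passed through to the kernel dispatch.
    return ops.gptq_marlin_gemm(a, qweight, scales, zeros, g_idx, perm,
                                workspace, scalar_types.uint4b8, size_m,
                                size_n, size_k, is_k_full=True)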
@ -10,11 +10,11 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    apply_awq_marlin_linear, awq_to_marlin_zero_points,
    check_awq_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace,
    marlin_permute_scales, replace_tensor, verify_awq_marlin_supported,
    verify_marlin_supports_shape)
    apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
    marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
    replace_tensor, verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)

@ -22,20 +22,31 @@ logger = init_logger(__name__)
class AWQMarlinConfig(QuantizationConfig):
    """Config class for AWQ Marlin"""

    # num_bits -> type
    TYPE_MAP = {
        4: scalar_types.uint4,
        8: scalar_types.uint8,
    }

    def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
                 lm_head_quantized: bool) -> None:
        self.weight_bits = weight_bits
        self.pack_factor = 32 // self.weight_bits  # packed into 32bits
        self.pack_factor = 32 // weight_bits  # packed into int32
        self.group_size = group_size
        self.has_zp = has_zp
        self.lm_head_quantized = lm_head_quantized

        verify_awq_marlin_supported(num_bits=self.weight_bits,
                                    group_size=self.group_size,
                                    has_zp=self.has_zp)
        if weight_bits not in self.TYPE_MAP:
            raise ValueError(f"Unsupported num_bits = {weight_bits}. "
                             f"Supported num_bits = {self.TYPE_MAP.keys()}")

        self.quant_type = self.TYPE_MAP[weight_bits]

        verify_marlin_supported(self.quant_type,
                                group_size=self.group_size,
                                has_zp=self.has_zp)

    def __repr__(self) -> str:
        return (f"AWQMarlinConfig(weight_bits={self.weight_bits}, "
        return (f"AWQMarlinConfig(quant_type={self.quant_type}, "
                f"group_size={self.group_size}, "
                f"has_zp={self.has_zp}, "
                f"lm_head_quantized={self.lm_head_quantized})")
@ -110,11 +121,13 @@ class AWQMarlinConfig(QuantizationConfig):
        if (num_bits is None or group_size is None or has_zp is None):
            return False

        return check_awq_marlin_supported(
            num_bits=num_bits,
            group_size=group_size,
            has_zp=has_zp,
            min_capability=cls.get_min_capability())
        if num_bits not in cls.TYPE_MAP:
            return False

        return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
                                      group_size=group_size,
                                      has_zp=has_zp,
                                      min_capability=cls.get_min_capability())


class AWQMarlinLinearMethod(LinearMethodBase):
@ -226,7 +239,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
            layer.qweight,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.weight_bits)
            num_bits=self.quant_config.quant_type.size_bits)
        replace_tensor(layer, "qweight", marlin_qweight)

        # Permute scales from AWQ format to marlin format.
@ -242,7 +255,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
            layer.qzeros,
            size_k=layer.num_groups,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.weight_bits)
            num_bits=self.quant_config.quant_type.size_bits)
        replace_tensor(layer, "qzeros", marlin_zp)

        # Not-used
@ -263,7 +276,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
            g_idx=layer.g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            num_bits=self.quant_config.weight_bits,
            quant_type=self.quant_config.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            bias=bias)
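Example (illustrative, not part of this commit): the TYPE_MAP pattern above replaces raw num_bits plumbing with a single lookup. AWQ_TYPE_MAP and resolve_awq_type are hypothetical names, and a working vLLM install is assumed.

from vllm.scalar_type import scalar_types

# AWQ checkpoints use unsigned types with a runtime zero-point, hence uint4/uint8.
AWQ_TYPE_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}


def resolve_awq_type(num_bits: int):
    if num_bits not in AWQ_TYPE_MAP:
        raise ValueError(f"Unsupported num_bits = {num_bits}. "
                         f"Supported num_bits = {list(AWQ_TYPE_MAP)}")
    return AWQ_TYPE_MAP[num_bits]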
@ -9,9 +9,13 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
from vllm.model_executor.utils import set_weight_attrs
from vllm.scalar_type import scalar_types

__all__ = ["CompressedTensorsW4A16Sparse24"]
W4A16SPARSE24_SUPPORTED_BITS = [4]
W4A16SPARSE24_SUPPORTED_TYPES_MAP = {
    4: scalar_types.uint4b8,
}
W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys())


class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
@ -22,9 +26,15 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
                 group_size: Optional[int] = None):
        self.strategy = strategy
        self.group_size = group_size
        self.num_bits = num_bits
        self.tile_size = 16

        if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP:
            raise ValueError(
                f"Unsupported num_bits = {num_bits}. "
                f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}")

        self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits]

        if self.strategy == "group" and self.group_size is None:
            raise ValueError(
                "group_size must be given when using strategy group")
@ -43,7 +53,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):

        pack_factor = 32 // self.num_bits
        pack_factor = 32 // self.quant_type.size_bits
        output_size_per_partition = sum(output_partition_sizes)

        qweight = Parameter(
@ -138,7 +148,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
        size_n = scales.shape[1]

        output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
                                            workspace, self.num_bits, size_m,
                                            workspace, self.quant_type, size_m,
                                            size_n, size_k)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
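Example (not part of this commit): the pack_factor line above is simple bit arithmetic; with a 4-bit type such as uint4b8 it evaluates as follows (plain Python, no vLLM needed).

size_bits = 4                 # what quant_type.size_bits returns for uint4b8
pack_factor = 32 // size_bits
assert pack_factor == 8       # eight 4-bit weights are packed into one int32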
@ -8,12 +8,17 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
    marlin_permute_scales, replace_tensor, verify_gptq_marlin_supported,
    marlin_permute_scales, replace_tensor, verify_marlin_supported,
    verify_marlin_supports_shape)
from vllm.model_executor.utils import set_weight_attrs
from vllm.scalar_type import scalar_types

__all__ = ["CompressedTensorsWNA16"]
WNA16_SUPPORTED_BITS = [4, 8]
WNA16_SUPPORTED_TYPES_MAP = {
    4: scalar_types.uint4b8,
    8: scalar_types.uint8b128,
}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())


class CompressedTensorsWNA16(CompressedTensorsScheme):
@ -22,8 +27,8 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
                 strategy: str,
                 num_bits: int,
                 group_size: Optional[int] = None):
        self.num_bits = num_bits
        self.pack_factor = 32 // self.num_bits

        self.pack_factor = 32 // num_bits
        self.strategy = strategy

        self.group_size: int
@ -37,10 +42,16 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
        else:
            self.group_size = group_size

        if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
            raise ValueError(
                f"Unsupported num_bits = {num_bits}. "
                f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")

        self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]

        # Verify supported on platform.
        verify_gptq_marlin_supported(num_bits=self.num_bits,
                                     group_size=self.group_size,
                                     is_sym=True)
        verify_marlin_supported(quant_type=self.quant_type,
                                group_size=self.group_size)

    @classmethod
    def get_min_capability(cls) -> int:
@ -150,7 +161,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
            perm=layer.g_idx_sort_indices,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.num_bits)
            num_bits=self.quant_type.size_bits)
        replace_tensor(layer, "weight_packed", marlin_qweight)

        # Permute scales from compressed-tensors format to marlin format.
@ -172,7 +183,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
            g_idx=layer.g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            num_bits=self.num_bits,
            wtype=self.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            is_k_full=True,
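Example (not part of this commit): the uint4b8/uint8b128 types used above carry the bias that GPTQ-style kernels assume, so dequantization subtracts it. A minimal sketch in plain PyTorch, with a made-up scale value:

import torch

q = torch.tensor([0, 8, 15])   # raw 4-bit codes as stored (range 0..15)
bias = 8                       # scalar_types.uint4b8.bias
scale = torch.tensor(0.05)     # per-group scale (assumed for the example)

dequant = (q - bias) * scale   # logical values: -0.40, 0.00, 0.35
print(dequant)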
@ -10,11 +10,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    apply_gptq_marlin_linear, check_gptq_marlin_supported, marlin_is_k_full,
    apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
    marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
    marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
    verify_gptq_marlin_supported, verify_marlin_supports_shape)
    verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)

@ -22,6 +23,12 @@ logger = init_logger(__name__)
class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin"""

    # (num_bits, is_sym) -> quant_type
    TYPE_MAP = {
        (4, True): scalar_types.uint4b8,
        (8, True): scalar_types.uint8b128,
    }

    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                 is_sym: bool, lm_head_quantized: bool) -> None:
        if desc_act and group_size == -1:
@ -29,20 +36,23 @@ class GPTQMarlinConfig(QuantizationConfig):
            # (since we have only one group per output channel)
            desc_act = False

        self.weight_bits = weight_bits
        self.pack_factor = 32 // self.weight_bits  # packed into int32
        self.pack_factor = 32 // weight_bits  # packed into int32
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym
        self.lm_head_quantized = lm_head_quantized

        if (weight_bits, is_sym) not in self.TYPE_MAP:
            raise ValueError("Unsupported quantization config: "
                             f"bits={weight_bits}, sym={is_sym}")

        self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]

        # Verify supported on platform.
        verify_gptq_marlin_supported(num_bits=self.weight_bits,
                                     group_size=self.group_size,
                                     is_sym=self.is_sym)
        verify_marlin_supported(quant_type=self.quant_type,
                                group_size=self.group_size)

    def __repr__(self) -> str:
        return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
        return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act}, "
                f"lm_head_quantized={self.lm_head_quantized})")
@ -122,11 +132,12 @@ class GPTQMarlinConfig(QuantizationConfig):
                or desc_act is None):
            return False

        return check_gptq_marlin_supported(
            num_bits=num_bits,
            group_size=group_size,
            is_sym=sym,
            min_capability=cls.get_min_capability())
        if (num_bits, sym) not in cls.TYPE_MAP:
            return False

        return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
                                      group_size=group_size,
                                      min_capability=cls.get_min_capability())


class GPTQMarlinLinearMethod(LinearMethodBase):
@ -293,7 +304,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
            perm=layer.g_idx_sort_indices,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.weight_bits)
            num_bits=self.quant_config.quant_type.size_bits)
        replace_tensor(layer, "qweight", marlin_qweight)

        # Permute scales from autogptq format to marlin format.
@ -319,7 +330,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
            g_idx=layer.g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            num_bits=self.quant_config.weight_bits,
            wtype=self.quant_config.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            is_k_full=layer.is_k_full,
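Example (illustrative, not part of this commit): GPTQMarlinConfig now keys its lookup on (num_bits, is_sym), so asymmetric GPTQ checkpoints are rejected up front. A sketch of that decision with a hypothetical helper name; assumes a working vLLM install.

from vllm.scalar_type import scalar_types

GPTQ_TYPE_MAP = {
    (4, True): scalar_types.uint4b8,
    (8, True): scalar_types.uint8b128,
}


def resolve_gptq_type(weight_bits: int, is_sym: bool):
    if (weight_bits, is_sym) not in GPTQ_TYPE_MAP:
        raise ValueError("Unsupported quantization config: "
                         f"bits={weight_bits}, sym={is_sym}")
    return GPTQ_TYPE_MAP[(weight_bits, is_sym)]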
@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)

@ -17,9 +18,10 @@ GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64

GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [
    scalar_types.uint4b8, scalar_types.uint8b128
]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
GPTQ_MARLIN_24_SUPPORTED_SYM = [True]


class GPTQMarlin24Config(QuantizationConfig):
@ -31,14 +33,19 @@ class GPTQMarlin24Config(QuantizationConfig):
        weight_bits: int,
        group_size: int,
    ) -> None:
        self.weight_bits = weight_bits
        quant_type = {
            4: scalar_types.uint4b8,
            8: scalar_types.uint8b128,
        }.get(weight_bits)

        self.group_size = group_size

        # Verify
        if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
        if quant_type is None or \
                quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES:
            raise ValueError(
                f"Marlin_24 does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} "
                f"Marlin_24 does not support quant_type = {quant_type}. "
                f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} "
                "are supported.")
        if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
            raise ValueError(
@ -46,8 +53,10 @@ class GPTQMarlin24Config(QuantizationConfig):
                f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
                "are supported.")

        self.quant_type = quant_type

        # 4 Bits packed into 32 bit datatype.
        self.pack_factor = 32 // self.weight_bits
        self.pack_factor = 32 // self.quant_type.size_bits

        # Tile size used by marlin kernels.
        self.tile_size = 16
@ -66,8 +75,8 @@ class GPTQMarlin24Config(QuantizationConfig):
        self.perm_len = 1024

    def __repr__(self) -> str:
        return "Marlin24Config(weight_bits={}, group_size={})".format(
            self.weight_bits, self.group_size)
        return "Marlin24Config(quant_type={}, group_size={})".format(
            self.quant_type, self.group_size)

    @classmethod
    def get_name(cls) -> str:
@ -279,7 +288,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):

        output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
                                            workspace,
                                            self.quant_config.weight_bits,
                                            self.quant_config.quant_type,
                                            size_m, size_n, size_k)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
@ -5,6 +5,7 @@ import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

from .quant_utils import pack_cols, unpack_cols

@ -13,7 +14,6 @@ GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

MARLIN_SUPPORTED_NUM_BITS = [4, 8]
MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

# In case there is a performance issue with Marlin, the variable below can be
@ -22,76 +22,70 @@ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
USE_FP32_REDUCE_DEFAULT = True


def _check_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
                            min_capability: Optional[int],
                            has_zp: bool) -> Tuple[bool, Optional[str]]:
    if min_capability is not None:
# For binary size and compile time, we don't support the same types for with and
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def query_marlin_supported_quant_types(has_zp: bool,
                                       min_capability: Optional[int] = None):
    if min_capability is None:
        major, minor = current_platform.get_device_capability()
        device_capability = major * 10 + minor
        if device_capability < min_capability:
            return (False, "Marlin does not support device_capability = {}"
                    ", the min_capability required is {}".format(
                        device_capability, min_capability))
        min_capability = major * 10 + minor

    if num_bits not in MARLIN_SUPPORTED_NUM_BITS:
        return (False, "Marlin does not support weight_bits = {}. "
                "Only weight_bits = {} are supported.".format(
                    num_bits, MARLIN_SUPPORTED_NUM_BITS))
    if min_capability < 80:
        return []

    if group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
        return (False, "Marlin does not support group_size = {}. Only "
                "group_sizes = {} are supported.".format(
                    group_size, MARLIN_SUPPORTED_GROUP_SIZES))
    if has_zp:
        # AWQ style, unsigned + runtime zero-point
        return [scalar_types.uint4, scalar_types.uint8]
    else:
        # GPTQ style, unsigned + symmetric bias
        # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
        # to add `scalar_types.float8_e4m3fn` here
        return [scalar_types.uint4b8, scalar_types.uint8b128]

    if not has_zp and not is_sym:
        return (False,
                "Marlin without zero_points must have symmetric quantization")

def _check_marlin_supported(
        quant_type: ScalarType,
        group_size: Optional[int],
        has_zp: bool,
        min_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:

    if min_capability is None:
        major, minor = current_platform.get_device_capability()
        min_capability = major * 10 + minor

    supported_types = query_marlin_supported_quant_types(
        has_zp, min_capability)

    if quant_type not in supported_types:
        return (False, f"Marlin does not support weight_bits = {quant_type}. "
                f"Only types = {supported_types} "
                f"are supported (for group_size = {group_size}, "
                f"min_capability = {min_capability}, zp = {has_zp}).")
    if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
        return (False, f"Marlin does not support group_size = {group_size}. "
                f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
                "are supported.")

    return True, None


def check_gptq_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
                                min_capability: int) -> bool:
    cond, _ = _check_marlin_supported(num_bits,
                                      group_size,
                                      is_sym,
                                      min_capability,
                                      has_zp=False)
def check_marlin_supported(quant_type: ScalarType,
                           group_size: int,
                           has_zp: bool = False,
                           min_capability: Optional[int] = None) -> bool:
    cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
                                      min_capability)
    return cond


def check_awq_marlin_supported(num_bits: int, group_size: int, has_zp: bool,
                               min_capability: int) -> bool:
    cond, _ = _check_marlin_supported(num_bits,
                                      group_size,
                                      False,
                                      min_capability,
                                      has_zp=has_zp)
    return cond


def verify_gptq_marlin_supported(num_bits: int, group_size: int,
                                 is_sym: bool) -> None:
    cond, err_msg = _check_marlin_supported(num_bits,
                                            group_size,
                                            is_sym,
                                            min_capability=None,
                                            has_zp=False)
def verify_marlin_supported(quant_type: ScalarType,
                            group_size: int,
                            has_zp: bool = False) -> None:
    cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
    if not cond:
        assert err_msg is not None
        raise ValueError("GPTQ" + err_msg)


def verify_awq_marlin_supported(num_bits: int, group_size: int,
                                has_zp: bool) -> None:
    cond, err_msg = _check_marlin_supported(num_bits,
                                            group_size,
                                            False,
                                            min_capability=None,
                                            has_zp=has_zp)
    if not cond:
        assert err_msg is not None
        raise ValueError("AWQ" + err_msg)
        raise ValueError(err_msg)


def verify_marlin_supports_shape(output_size_per_partition: int,
@ -245,7 +239,7 @@ def apply_gptq_marlin_linear(
        g_idx: torch.Tensor,
        g_idx_sort_indices: torch.Tensor,
        workspace: torch.Tensor,
        num_bits: int,
        wtype: ScalarType,
        output_size_per_partition: int,
        input_size_per_partition: int,
        is_k_full: bool,
@ -261,7 +255,7 @@ def apply_gptq_marlin_linear(
        g_idx,
        g_idx_sort_indices,
        workspace,
        num_bits,
        wtype,
        size_m=reshaped_x.shape[0],
        size_n=output_size_per_partition,
        size_k=input_size_per_partition,
@ -283,7 +277,7 @@ def apply_awq_marlin_linear(
        g_idx: torch.Tensor,
        g_idx_sort_indices: torch.Tensor,
        workspace: torch.Tensor,
        num_bits: int,
        quant_type: ScalarType,
        output_size_per_partition: int,
        input_size_per_partition: int,
        bias: Optional[torch.Tensor] = None,
@ -298,7 +292,7 @@ def apply_awq_marlin_linear(
        g_idx,
        g_idx_sort_indices,
        workspace,
        num_bits,
        quant_type,
        size_m=reshaped_x.shape[0],
        size_n=output_size_per_partition,
        size_k=input_size_per_partition,
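Example (not part of this commit): a usage sketch for the reworked helpers above; it assumes a working vLLM install and passes min_capability explicitly so no GPU query is needed.

from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    check_marlin_supported, query_marlin_supported_quant_types)
from vllm.scalar_type import scalar_types

# GPTQ-style types (no runtime zero-point) available on SM80-class hardware.
print(query_marlin_supported_quant_types(has_zp=False, min_capability=80))

# Would 4-bit AWQ (runtime zero-point) with group_size=128 be accepted?
print(check_marlin_supported(scalar_types.uint4, group_size=128, has_zp=True,
                             min_capability=80))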
@ -5,10 +5,12 @@ from typing import List
import numpy as np
import torch

from vllm.scalar_type import ScalarType

from .marlin_utils import (GPTQ_MARLIN_TILE, marlin_permute_scales,
                           marlin_zero_points)
from .quant_utils import (get_pack_factor, quantize_weights,
                          quantize_weights_with_zp, sort_weights)
from .quant_utils import (get_pack_factor, gptq_quantize_weights,
                          quantize_weights, sort_weights)


class MarlinWorkspace:
@ -90,9 +92,10 @@ def get_weight_perm(num_bits: int):
    return perm


def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int,
                    act_order: bool):
    size_k, size_n = w.shape
    num_bits = quant_type.size_bits

    # Normalize group_size
    if group_size == -1:
@ -100,8 +103,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
    assert group_size <= size_k

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
                                                       act_order)
    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
        w, quant_type, group_size, act_order)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
@ -122,7 +125,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
    return res_list


def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int):
def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType,
                        group_size: int):
    size_k, size_n = w.shape

    # Normalize group_size
@ -135,13 +139,18 @@ def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int):
    num_groups = size_k // group_size

    # Quantize with zp
    w_ref, q_w, s, zp = quantize_weights_with_zp(w, num_bits, group_size)
    w_ref, q_w, s, zp = quantize_weights(w,
                                         quant_type,
                                         group_size,
                                         zero_points=True)

    # Reformat to marlin
    weight_perm = get_weight_perm(num_bits)
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
                                weight_perm)
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
    marlin_zp = marlin_zero_points(zp, num_groups, size_n, num_bits)
    marlin_zp = marlin_zero_points(zp, num_groups, size_n,
                                   quant_type.size_bits)

    # Create result
    res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]

@ -6,8 +6,10 @@ from typing import List
import numpy
import torch

from vllm.scalar_type import ScalarType

from .marlin_utils_test import marlin_weights
from .quant_utils import quantize_weights
from .quant_utils import gptq_quantize_weights


# This is PyTorch implementation of main part of reorder_meta()
@ -348,13 +350,11 @@ def check_24(w, num_rows_to_sample=50, _verbose=False):
    print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")


def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
    assert q_24.shape == (size_k, size_n)

    # Remove zp to normalize over 0
    max_q_val = (1 << num_bits) - 1
    zp = (max_q_val + 1) // 2
    q_24_no_zp = q_24 - zp
    # Remove bias to normalize over 0
    q_24_no_zp = q_24 - wtype.bias

    # Compress
    q_24_no_zp = q_24_no_zp.t().contiguous()
@ -362,8 +362,8 @@ def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
                                                 q_24_no_zp)
    q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()

    # Restore zp
    q_24_comp = q_24_no_zp_comp + zp
    # Restore bias
    q_24_comp = q_24_no_zp_comp + wtype.bias

    # Resize meta to its actual shape (without moving any data)
    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
@ -427,7 +427,7 @@ def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int,

def marlin_24_quantize(
    w: torch.Tensor,
    num_bits: int,
    quant_type: ScalarType,
    group_size: int,
):
    size_k, size_n = w.shape
@ -441,20 +441,18 @@ def marlin_24_quantize(
    w_24, mask_24 = inject_24(w, size_k, size_n)

    # Quantize
    w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24,
                                                             num_bits,
                                                             group_size,
                                                             act_order=False)
    w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights(
        w_24, quant_type, group_size, act_order=False)

    # Compress quantized weight
    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,
                                                     num_bits)
                                                     quant_type)
    size_k_comp = size_k // 2

    # Reformat to marlin
    weight_perm = get_weight_perm_24(num_bits)
    weight_perm = get_weight_perm_24(quant_type.size_bits)
    marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,
                                        num_bits, weight_perm)
                                        quant_type.size_bits, weight_perm)
    marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)

    # Create result

@ -4,7 +4,11 @@ from typing import List
import numpy
import torch

SUPPORTED_NUM_BITS = [4, 8]
from vllm.model_executor.layers.quantization.qqq import (
    MARLIN_QQQ_SUPPORTED_NUM_BITS)
from vllm.scalar_type import ScalarType, scalar_types

SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

# Note: this is a hack. We should update each model to register the
@ -45,7 +49,7 @@ def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:


def get_pack_factor(num_bits):
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits


@ -74,24 +78,23 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
    )


def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
                     act_order: bool):
def quantize_weights(w: torch.Tensor,
                     quant_type: ScalarType,
                     group_size: int,
                     zero_points: bool = False):
    assert quant_type.is_integer(), \
        "Floating point quantization may work but has not been tested"

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    max_q_val = 2**num_bits - 1
    half_q_val = (max_q_val + 1) // 2

    # Reshape to [groupsize, -1]
    if group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
@ -99,16 +102,34 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    s = torch.max(torch.abs(w), 0, keepdim=True)[0]
    s *= 2 / max_q_val  # 2 => symmetric
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    max_q_val = quant_type.max()
    min_q_val = quant_type.min()

    if zero_points:
        assert not quant_type.is_signed() and quant_type.max() > 0
        w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
        maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
            .clamp(min_q_val, max_q_val).int()
    else:
        # If the bias is such that there are no possible negative/positive
        # values, set the max value to inf to avoid divide by 0
        w_s = torch.max(
            abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
            abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
        maybe_w_zp = None

    # Quantize
    q_w = torch.round(w / s).int()
    q_w += half_q_val
    q_w = torch.clamp(q_w, 0, max_q_val)
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    w_ref = (q_w - half_q_val).half() * s
    w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    if quant_type.has_bias():
        w_q += quant_type.bias

    # Restore original shapes
    if group_size < size_k:
@ -119,10 +140,35 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        q_w = reshape_w(q_w)
        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)

    s = s.reshape((-1, size_n)).contiguous()
    w_s = w_s.reshape((-1, size_n)).contiguous()

    if zero_points:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s.to(device=orig_device),
        maybe_w_zp,
    )


def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
                          group_size: int, act_order: bool):
    size_k, _ = w.shape

    assert w.is_floating_point(), "w must be float"
    assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \
        f"Unsupported gptq type = {quant_type}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)

    # Apply act_order
    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
@ -133,76 +179,9 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
    ), "For act_order, groupsize = {} must be less than size_k = {}".format(
        group_size, size_k)

    w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size)
    w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size)

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        s.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )


def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
    orig_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    max_q_val = 2**num_bits - 1
    min_q_val = 0

    # Reshape to [groupsize, -1]
    if group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max = torch.max(w, 0, keepdim=True)[0]
    min = torch.min(w, 0, keepdim=True)[0]
    s = (max - min).clamp(min=1e-5) / max_q_val

    # Compute zero-point for each group
    zp = (-torch.round(min / s)).clamp(min_q_val, max_q_val).int()

    # Quantize
    q_w = torch.round(w / s).int() + zp
    q_w = torch.clamp(q_w, min_q_val, max_q_val)

    # Compute ref (dequantized)
    w_ref = (q_w - zp).half() * s

    # Restore original shapes
    if group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        q_w = reshape_w(q_w)
        w_ref = reshape_w(w_ref)

    s = s.reshape((-1, size_n)).contiguous()
    zp = zp.reshape((-1, size_n)).contiguous()

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        s.to(device=orig_device),
        zp.to(device=orig_device),
    )
    return w_ref, w_q, w_s, g_idx, rand_perm


# QQQ employs different quant schemes for per-group and
@ -212,7 +191,8 @@ def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \
        f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

vllm/scalar_type.py (new file, 35 lines)
@ -0,0 +1,35 @@
from ._core_ext import NanRepr, ScalarType

# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
# flags:
# - no-flags: means it follows IEEE 754 conventions
# - f: means finite values only (no infinities)
# - n: means nans are supported (non-standard encoding)
# for integer types the scheme is:
# `[u]int<size_bits>[b<bias>]`
# - if bias is not present it means its zero


class scalar_types:
    int4 = ScalarType.int_(4, None)
    uint4 = ScalarType.uint(4, None)
    int8 = ScalarType.int_(8, None)
    uint8 = ScalarType.uint(8, None)
    float8_e4m3fn = ScalarType.float_(4, 3, True,
                                      NanRepr.EXTD_RANGE_MAX_MIN.value)
    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
    float16_e5m10 = ScalarType.float_IEEE754(5, 10)

    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value)

    # "gptq" types
    uint4b8 = ScalarType.uint(4, 8)
    uint8b128 = ScalarType.uint(8, 128)

    # colloquial names
    bfloat16 = float16_e8m7
    float16 = float16_e5m10
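Example (not part of this commit): what the naming scheme in vllm/scalar_type.py encodes, written as runnable assertions; assumes vLLM is installed with the compiled core extension.

from vllm.scalar_type import scalar_types

assert scalar_types.uint4b8.size_bits == 4
assert scalar_types.uint4b8.bias == 8        # "b8" suffix encodes the bias
assert scalar_types.uint4.bias == 0          # no suffix means zero bias
assert scalar_types.float8_e5m2.is_ieee_754()        # no flags: IEEE 754
assert not scalar_types.float8_e4m3fn.is_ieee_754()  # "fn": finite, custom NaN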