[Bugfix] Fix support for dimension-like integers and ScalarType (#9299)
parent 0f41fbe5a3
commit eca2c5f7c0
@@ -230,14 +230,12 @@ steps:
   commands:
   - pytest -v -s compile/test_basic_correctness.py

-# TODO: re-write in comparison tests, and fix symbolic shape
-# for quantization ops.
-# - label: "PyTorch Fullgraph Test" # 18min
-#   source_file_dependencies:
-#   - vllm/
-#   - tests/compile
-#   commands:
-#   - pytest -v -s compile/test_full_graph.py
+- label: "PyTorch Fullgraph Test" # 18min
+  source_file_dependencies:
+  - vllm/
+  - tests/compile
+  commands:
+  - pytest -v -s compile/test_full_graph.py

 - label: Kernels Test %N # 1h each
   mirror_hardwares: [amd]
@@ -83,24 +83,6 @@ endif()
 #
 find_package(Torch REQUIRED)

-#
-message(STATUS "Enabling core extension.")
-
-# Define _core_C extension
-#  built for (almost) every target platform, (excludes TPU and Neuron)
-
-set(VLLM_EXT_SRC
-  "csrc/core/torch_bindings.cpp")
-
-define_gpu_extension_target(
-  _core_C
-  DESTINATION vllm
-  LANGUAGE CXX
-  SOURCES ${VLLM_EXT_SRC}
-  COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
-  USE_SABI 3
-  WITH_SOABI)
-
 #
 # Forward the non-CUDA device extensions to external CMake scripts.
 #
@@ -1,6 +1,7 @@
 #pragma once

-#include <torch/custom_class.h>
+// For TORCH_CHECK
+#include <torch/library.h>

 namespace vllm {

@@ -9,12 +10,7 @@ namespace vllm {
 // in particular it can be used to represent sub-byte data types (something
 // that torch.dtype currently does not support).
 //
-// ScalarTypeTorch is a subclass of ScalarType that is compatible with
-// TORCH_LIBRARY, making it accessible from Python as well meaning this class
-// can be used as a argument for custom operators, helping to simplify these
-// interfaces.
-//
-// The type definitions on the Python side can be found in: vllm/_core_ext.pyi
+// The type definitions on the Python side can be found in: vllm/scalar_type.py
 // these type definitions should be kept up to date with any Python API changes
 // here.
 //
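Editor's note (not part of the diff): the comment above now points at the pure-Python definitions in vllm/scalar_type.py instead of the removed _core_ext stub. A minimal sketch of what using those Python-side definitions looks like, assuming an installed vllm package where ScalarType, the scalar_types constants, and the .id/.size_bits/.bias attributes exist as of this change:

    # Illustrative only; names come from vllm/scalar_type.py, which mirrors
    # csrc/core/scalar_type.hpp.
    from vllm.scalar_type import ScalarType, scalar_types

    u4b8 = ScalarType.uint(4, 8)          # 4-bit unsigned stored with a bias of 8
    assert u4b8 == scalar_types.uint4b8   # same value as the predefined constant
    print(u4b8.size_bits, u4b8.bias, u4b8.id)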
@@ -308,204 +304,7 @@ class ScalarType {
   }
 };

-// Create a TORCH_LIBRARY compatible version of ScalarType (i.e. inherit from
-// torch::CustomClassHolder), we use multiple inheritance here since we cannot
-// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
-// constructor at the same time (torch::CustomClassHolder does not have a
-// constexpr destructor)
-// See also:
-// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
-class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
- public:
-  ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
-                  bool _signed)
-      : ScalarType(exponent, mantissa, bias, _signed){};
-
-  ScalarTypeTorch(ScalarType type) : ScalarType(type){};
-
-  using Base = ScalarType;
-  using Self = ScalarTypeTorch;
-  using SelfPtr = c10::intrusive_ptr<Self>;
-
-  static void check_size_bits(int64_t size_bits, bool signed_) {
-    TORCH_CHECK(
-        size_bits <=
-            std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
-        "size_bits bit width is too large to be represented");
-  }
-
-  static void check_bias(int64_t bias) {
-    using Bias = decltype(std::declval<Self>().bias);
-    TORCH_CHECK(bias <= std::numeric_limits<Bias>::max() &&
-                    bias >= std::numeric_limits<Bias>::min(),
-                "bias too large or small to be represented");
-  }
-
-  static void check_exponent(int64_t exponent) {
-    TORCH_CHECK(
-        exponent <=
-            std::numeric_limits<decltype(std::declval<Self>().exponent)>::max(),
-        "exponent bit width is too large to be represented");
-  }
-
-  static void check_mantissa(int64_t mantissa) {
-    TORCH_CHECK(
-        mantissa <=
-            std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
-        "mantissa bit width is too large to be represented");
-  }
-
-  static SelfPtr int_(int64_t size_bits, c10::optional<int64_t> bias) {
-    check_size_bits(size_bits, true);
-    check_bias(bias.value_or(0));
-    return c10::make_intrusive<Self>(
-        ScalarType::int_(size_bits, bias.value_or(0)));
-  }
-
-  static SelfPtr uint(int64_t size_bits, c10::optional<int64_t> bias) {
-    check_size_bits(size_bits, true);
-    check_bias(bias.value_or(0));
-    return c10::make_intrusive<Self>(
-        ScalarType::uint(size_bits, bias.value_or(0)));
-  }
-
-  static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) {
-    check_mantissa(mantissa);
-    check_exponent(exponent);
-    return c10::make_intrusive<Self>(
-        ScalarType::float_IEEE754(exponent, mantissa));
-  }
-
-  static SelfPtr float_(int64_t exponent, int64_t mantissa,
-                        bool finite_values_only, int64_t nan_repr) {
-    check_mantissa(mantissa);
-    check_exponent(exponent);
-    return c10::make_intrusive<Self>(ScalarType::float_(
-        exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
-  }
-
-  // This needs to be implemented and throw a TypeError in order for
-  // PyTorch's opcheck to work on ops that use ScalarTypes.
-  int64_t len() const {
-    throw c10::TypeError({__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
-                         "__len__ not implemented");
-    return 0;
-  }
-
-  // Serialize a ScalarType into a tuple of pairs. Where each pair
-  // is a (fieldname, value).
-  // For simplicity, we are just going to convert to a ScalarTypeId.
-  std::tuple<std::tuple<std::string, int64_t>> obj_flatten() const {
-    return {{"ScalarType", id()}};
-  }
-
-  // Deserialize a scalar type that has been serialized by obj_flatten,
-  // ostensibly from a tuple of (member name, value) pairs, but in reality
-  // just a ScalarTypeId.
-  static SelfPtr obj_unflatten(
-      std::tuple<std::tuple<std::string, int64_t>> const& flat_type) {
-    return c10::make_intrusive<Self>(
-        from_id(std::get<1>(std::get<0>(flat_type))));
-  }
-
-  template <typename T>
-  static void bind_readonly_property(torch::class_<Self>& cls,
-                                     std::string const& name, T Base::*field) {
-    auto getter_func_helper = [field = std::move(field)](SelfPtr const& self) {
-      if constexpr (std::is_member_function_pointer_v<decltype(field)>) {
-        return (self.get()->*field)();
-      } else {
-        return self.get()->*field;
-      }
-    };
-
-    auto getter_func = [field = std::move(field),
-                        getter_func_helper = std::move(getter_func_helper)](
-                           SelfPtr const& self) {
-      auto val = getter_func_helper(self);
-      // upconvert uint8_t, int32_t etc. to int64_t for python
-      if constexpr (std::is_integral_v<T>) {
-        return static_cast<int64_t>(val);
-      } else {
-        return val;
-      }
-    };
-
-    cls.def_property(name, getter_func);
-  }
-
-  template <typename MemberFunc, typename Cls>
-  static void bind_function(torch::class_<Self>& cls, const std::string& name,
-                            MemberFunc Cls::*member) {
-    cls.def(name, [member = std::move(member)](SelfPtr const& self) {
-      return (self.get()->*member)();
-    });
-  }
-
-  template <typename Func>
-  static void bind_function(torch::class_<Self>& cls, const std::string& name,
-                            Func func) {
-    cls.def(name, func);
-  }
-
-  template <typename Func>
-  static void bind_static_function(torch::class_<Self>& cls,
-                                   const std::string& name, Func func) {
-    cls.def_static(name, func);
-  }
-
-  static void bind_class(torch::Library& lib) {
-    auto cls = lib.class_<ScalarTypeTorch>("ScalarType")
-                   .def(torch::init<int64_t, int64_t, int64_t, bool>());
-
-    // Bind Properties
-    bind_readonly_property(cls, "mantissa", &Base::mantissa);
-    bind_readonly_property(cls, "exponent", &Base::exponent);
-    bind_readonly_property(cls, "bias", &Base::bias);
-    bind_readonly_property(cls, "signed", &Base::is_signed);
-    bind_readonly_property(cls, "size_bits", &Base::size_bits);
-
-    // Bind member functions
-    bind_function(cls, "is_signed", &Base::is_signed);
-    bind_function(cls, "is_integer", &Base::is_integer);
-    bind_function(cls, "is_floating_point", &Base::is_floating_point);
-    bind_function(cls, "is_ieee_754", &Base::is_ieee_754);
-    bind_function(cls, "has_nans", &Base::has_nans);
-    bind_function(cls, "has_infs", &Base::has_infs);
-    bind_function(cls, "has_bias", &Base::has_bias);
-
-    bind_function(cls, "max", [](SelfPtr const& self) {
-      return std::visit([](auto arg) { return c10::IValue(arg); },
-                        self.get()->max());
-    });
-    bind_function(cls, "min", [](SelfPtr const& self) {
-      return std::visit([](auto arg) { return c10::IValue(arg); },
-                        self.get()->min());
-    });
-
-    bind_function(cls, "__len__", &ScalarTypeTorch::len);
-    bind_function(cls, "__str__", &Base::str);
-    bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
-      return *self == *other;
-    });
-    bind_function(cls, "__repr__", [](SelfPtr const& self) {
-      return "ScalarType." + self.get()->str();
-    });
-
-    bind_function(cls, "__obj_flatten__", &ScalarTypeTorch::obj_flatten);
-    bind_static_function(cls, "__obj_unflatten__",
-                         &ScalarTypeTorch::obj_unflatten);
-
-    // Bind static functions (convenience constructors)
-    bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
-    bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
-    bind_static_function(cls, "float_IEEE754", &ScalarTypeTorch::float_IEEE754);
-    bind_static_function(cls, "float_", &ScalarTypeTorch::float_);
-  }
-};
-
-using ScalarTypeId = int64_t;
-using ScalarTypeTorchPtr = c10::intrusive_ptr<ScalarTypeTorch>;
+using ScalarTypeId = ScalarType::Id;

 // "rust style" names generally following:
 // https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
@@ -1,16 +0,0 @@
-#include <torch/library.h>
-
-#include "scalar_type.hpp"
-#include "registration.h"
-
-// Note the CORE exstension will be built for (almost) all hardware targets so
-// new additions must account for this. (currently not built for TPU and Neuron)
-
-TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, lib) {
-  // ScalarType, a custom class for representing data types that supports
-  // quantized types, declared here so it can be used when creating interfaces
-  // for custom ops.
-  vllm::ScalarTypeTorch::bind_class(lib);
-}
-
-REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
@@ -484,21 +484,22 @@ torch::Tensor marlin_gemm_moe(
     const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
     torch::Tensor& b_zeros, const torch::Tensor& g_idx,
     const torch::Tensor& perm, torch::Tensor& workspace,
-    vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n,
+    vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n,
     int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
     int64_t moe_block_size, bool replicate_input, bool apply_weights) {
+  vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
   bool has_zp = b_zeros.size(1) != 0;
   if (has_zp) {
     TORCH_CHECK(
-        *b_q_type == vllm::kU4,
-        "b_q_type must be u4 when has_zp = True. Got = ", b_q_type->str());
+        b_q_type == vllm::kU4,
+        "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
   } else {
     TORCH_CHECK(
-        *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
-        "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str());
+        b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
+        "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str());
   }

-  int pack_factor = 32 / b_q_type->size_bits();
+  int pack_factor = 32 / b_q_type.size_bits();

   int max_par = 4;

@@ -575,7 +576,7 @@ torch::Tensor marlin_gemm_moe(
       topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
       b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
       expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
-      *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
+      b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
       num_experts, topk, moe_block_size, dev,
       at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par,
       replicate_input, apply_weights);
@@ -13,8 +13,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
       "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
       "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
-      "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
-      "int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
+      "int b_q_type, SymInt size_m, "
+      "SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
+      "topk, "
       "int moe_block_size, bool replicate_input, bool apply_weights)"
       " -> Tensor");
   // conditionally compiled so impl registration is in source file
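Editor's note (not part of the diff): this is the first of several schemas where dimension-like arguments switch from int to SymInt, which is what lets torch.compile trace these ops with symbolic shapes. A self-contained sketch of the pattern under a hypothetical demo namespace (symint_demo and tile_rows are made up for illustration; assumes PyTorch 2.4+ for torch.library.register_fake):

    import torch
    from torch.library import Library, register_fake

    demo_lib = Library("symint_demo", "DEF")  # hypothetical namespace
    demo_lib.define("tile_rows(Tensor x, SymInt times) -> Tensor")

    def _tile_rows_impl(x: torch.Tensor, times: int) -> torch.Tensor:
        # Eager kernel: repeat the rows `times` times.
        return x.repeat(times, 1)

    demo_lib.impl("tile_rows", _tile_rows_impl, "CompositeExplicitAutograd")

    @register_fake("symint_demo::tile_rows")
    def _tile_rows_fake(x, times):
        # `times` may be a SymInt during tracing; deriving the output shape
        # from it keeps the captured graph symbolic instead of specialized.
        return x.new_empty(x.shape[0] * times, x.shape[1])

    @torch.compile(fullgraph=True, dynamic=True)
    def run(x, times):
        return torch.ops.symint_demo.tile_rows(x, times)

    print(run(torch.ones(2, 3), 4).shape)  # torch.Size([8, 3])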
@@ -80,7 +80,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
                               torch::Tensor& g_idx, torch::Tensor& perm,
                               torch::Tensor& workspace,
-                               vllm::ScalarTypeTorchPtr const& b_q_type,
+                               vllm::ScalarTypeId const b_q_type_id,
                               int64_t size_m, int64_t size_n, int64_t size_k,
                               bool is_k_full, bool has_zp) {
   TORCH_CHECK_NOT_IMPLEMENTED(false,
@@ -2132,22 +2132,23 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
                               torch::Tensor& g_idx, torch::Tensor& perm,
                               torch::Tensor& workspace,
-                               vllm::ScalarTypeTorchPtr const& b_q_type,
+                               vllm::ScalarTypeId const& b_q_type_id,
                               int64_t size_m, int64_t size_n, int64_t size_k,
                               bool is_k_full, bool has_zp,
                               bool use_fp32_reduce) {
+  vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
   if (has_zp) {
-    TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8,
-                "b_q_type must be u4 or u8 when has_zp = True. Got = ",
-                b_q_type->str());
+    TORCH_CHECK(
+        b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
+        "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
   } else {
     TORCH_CHECK(
-        *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
+        b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
         "b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
-        b_q_type->str());
+        b_q_type.str());
   }

-  int pack_factor = 32 / b_q_type->size_bits();
+  int pack_factor = 32 / b_q_type.size_bits();

   // Verify A
   TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
@@ -2279,7 +2280,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
         b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
         a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
-        workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
+        workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
         thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
@@ -2288,7 +2289,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
         b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
         perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
-        workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
+        workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
         thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
   } else {
@@ -38,9 +38,10 @@ static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
 // Interface
 //

-std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
+std::vector<std::string> supported_schedules(ScalarTypeId const btype_id) {
 #if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
-  return scalar_type_dispatch(*btype, [&](auto BType) {
+  vllm::ScalarType b_type = ScalarType::from_id(btype_id);
+  return scalar_type_dispatch(b_type, [&](auto BType) {
     return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
   });
 #else
@@ -49,7 +50,7 @@ std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
 }

 torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
-                   ScalarTypeTorchPtr const& btype,
+                   ScalarTypeId const btype_id,
                    c10::optional<torch::Tensor> const& scales,
                    c10::optional<torch::Tensor> const& zeros,
                    c10::optional<int64_t> group_size,
@@ -57,6 +58,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
                    c10::optional<double> alpha, c10::optional<double> beta,
                    c10::optional<std::string> schedule) {
 #if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
+  ScalarType const btype = ScalarType::from_id(btype_id);
   auto args = PyTorchArguments{.A = A,
                                .B = B,
                                .scales = scales,
@@ -67,7 +69,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
                                .beta = beta,
                                .schedule = schedule};

-  return scalar_type_dispatch(*btype, [&](auto BType) {
+  return scalar_type_dispatch(btype, [&](auto BType) {
     return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
         A.scalar_type(), "machete_gemm", [&] {
           using ComputeType = equivalent_cutlass_type_t<scalar_t>;
@@ -79,9 +81,9 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
 #endif
 }

-torch::Tensor prepack_B(torch::Tensor const& B,
-                        vllm::ScalarTypeTorchPtr const& btype) {
-  return scalar_type_dispatch(*btype, [&](auto BType) {
+torch::Tensor prepack_B(torch::Tensor const& B, ScalarTypeId const btype_id) {
+  ScalarType const btype = ScalarType::from_id(btype_id);
+  return scalar_type_dispatch(btype, [&](auto BType) {
     return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
   });
 }
@@ -89,7 +89,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  torch::Tensor& b_meta,
                                  torch::Tensor& b_scales,
                                  torch::Tensor& workspace,
-                                  vllm::ScalarTypeTorchPtr const& b_q_type,
+                                  vllm::ScalarTypeId const b_q_type_id,
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k) {
   TORCH_CHECK_NOT_IMPLEMENTED(
@@ -1029,13 +1029,14 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  torch::Tensor& b_meta,
                                  torch::Tensor& b_scales,
                                  torch::Tensor& workspace,
-                                  vllm::ScalarTypeTorchPtr const& b_q_type,
+                                  vllm::ScalarTypeId const b_q_type_id,
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k) {
+  vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
   // Verify num_bits
-  TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
-              "num_bits must be uint4b8 or uint8b128. Got = ", b_q_type->str());
-  int pack_factor = 32 / b_q_type->size_bits();
+  TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
+              "num_bits must be uint4b8 or uint8b128. Got = ", b_q_type.str());
+  int pack_factor = 32 / b_q_type.size_bits();

   // Verify M
   TORCH_CHECK(size_m == a.size(0),
@@ -1130,8 +1131,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   marlin_24::marlin_cuda_2_4(
       a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(),
       b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(),
-      b_q_type->size_bits(), groupsize, dev,
-      at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par);
+      b_q_type.size_bits(), groupsize, dev, at::cuda::getCurrentCUDAStream(dev),
+      thread_k, thread_m, sms, max_par);

   return c;
 }
@@ -140,13 +140,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Quantized GEMM for AWQ.
   ops.def(
       "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, int split_k_iters) -> Tensor");
+      "Tensor _zeros, SymInt split_k_iters) -> Tensor");
   ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);

   // Dequantization for AWQ.
   ops.def(
       "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, int split_k_iters, int thx, int thy) -> Tensor");
+      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
   ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);

   // Note about marlin kernel 'workspace' arguments:
@@ -166,32 +166,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
   ops.def(
       "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
-      "Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
+      "Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
+      "Tensor");
   // conditionally compiled so impl in source file

   // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
   ops.def(
       "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
       "Tensor b_scales, Tensor workspace, "
-      "__torch__.torch.classes._core_C.ScalarType b_q_type, "
-      "int size_m, int size_n, int size_k) -> Tensor");
+      "int b_q_type, "
+      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
   // conditionally compiled so impl in source file

   // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
+  ops.def("machete_supported_schedules(int btype) -> str[]");
   ops.def(
-      "machete_supported_schedules("
-      " __torch__.torch.classes._core_C.ScalarType btype"
-      ") -> str[]");
-  ops.def(
-      "machete_gemm(Tensor A, Tensor B,"
-      " __torch__.torch.classes._core_C.ScalarType btype,"
-      " Tensor? scales, Tensor? zeros, int? group_size,"
+      "machete_gemm(Tensor A, Tensor B, int btype, "
+      " Tensor? scales, Tensor? zeros, int? group_size, "
      " Tensor? C, float? alpha, float? beta, str? schedule)"
      "-> Tensor");
-  ops.def(
-      "machete_prepack_B(Tensor B,"
-      " __torch__.torch.classes._core_C.ScalarType btype)"
-      "-> Tensor");
+  ops.def("machete_prepack_B(Tensor B, int btype) -> Tensor");
   // conditionally compiled so impl registration is in source file

   ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
@@ -201,8 +195,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
      "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
-      "__torch__.torch.classes._core_C.ScalarType b_q_type, "
-      "int size_m, int size_n, int size_k, bool is_k_full, "
+      "int b_q_type, "
+      "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
      "bool has_zp, bool use_fp32_reduce) -> Tensor");
   // conditionally compiled so impl registration is in source file

@@ -219,32 +213,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // conditionally compiled so impl registrations are in source file

   // Dequantization for GGML.
-  ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
+  ops.def("ggml_dequantize(Tensor W, int type, SymInt m, SymInt n) -> Tensor");
   ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);

   // mmvq kernel for GGML.
   ops.def(
-      "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, int row) "
+      "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) "
      "-> Tensor");
   ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);

   // mmq kernel for GGML.
-  ops.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, int row) -> Tensor");
+  ops.def(
+      "ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
   ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);

   // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
   ops.def(
       "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
-      "Tensor! workspace, int num_bits, int size_m, int size_n, "
-      "int size_k) -> Tensor");
+      "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
+      "SymInt size_k) -> Tensor");
   // conditionally compiled so impl registration is in source file

   // marlin_qqq_gemm for QQQ.
   ops.def(
       "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
      "Tensor s_tok, Tensor s_ch, Tensor s_group, "
-      "Tensor! workspace, int size_m, int size_n, "
-      "int size_k) -> Tensor");
+      "Tensor! workspace, SymInt size_m, SymInt size_n, "
+      "SymInt size_k) -> Tensor");
   // conditionally compiled so impl registration is in source file

   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
@@ -39,7 +39,6 @@ assert cwd != package_path, "should not import from the current directory"

 files_to_copy = [
     "vllm/_C.abi3.so",
-    "vllm/_core_C.abi3.so",
     "vllm/_moe_C.abi3.so",
     "vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so",
     "vllm/vllm_flash_attn/flash_attn_interface.py",
setup.py (7 changed lines)
@@ -290,10 +290,6 @@ def _build_custom_ops() -> bool:
     return _is_cuda() or _is_hip() or _is_cpu()


-def _build_core_ext() -> bool:
-    return not (_is_neuron() or _is_tpu() or _is_openvino() or _is_xpu())
-
-
 def get_hipcc_rocm_version():
     # Run the hipcc --version command
     result = subprocess.run(['hipcc', '--version'],
@@ -456,9 +452,6 @@ def get_requirements() -> List[str]:

 ext_modules = []

-if _build_core_ext():
-    ext_modules.append(CMakeExtension(name="vllm._core_C"))
-
 if _is_cuda() or _is_hip():
     ext_modules.append(CMakeExtension(name="vllm._moe_C"))

@@ -69,11 +69,11 @@ def check_full_graph_support(model,
     os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
     os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"

-    # Inductor doesn't support fp8/gptq_marlin_24 yet.
+    # Inductor doesn't support fp8 and the base meta llama uses too
+    # much memory.
     quantization = model_kwargs.get("quantization")
-    if (quantization == "fp8" or quantization == "gptq_marlin"
-            or quantization == "gptq_marlin_24"
-        ) and optimization_level >= CompilationLevel.INDUCTOR:
+    if ((quantization == "fp8" or model == "meta-llama/Meta-Llama-3-8B")
+            and optimization_level >= CompilationLevel.INDUCTOR):
         return

     prompts = [
@@ -80,7 +80,7 @@ def machete_quantize_and_pack(w: torch.Tensor,
     w_q = w_q.t().contiguous().t()  # convert to col major
     w_q_machete = ops.machete_prepack_B(w_q, wtype)

-    opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype))
+    opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype.id))

     return w_ref, w_q_machete, w_s, w_zp

@@ -153,8 +153,9 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
         schedule=schedule,
     )

-    opcheck(torch.ops._C.machete_gemm,
-            (a, w_q_machete, wtype, w_s, maybe_convert_zeropoints(
+    opcheck(
+        torch.ops._C.machete_gemm,
+        (a, w_q_machete, wtype.id, w_s, maybe_convert_zeropoints(
             w_zp, w_s), group_size, None, None, None, schedule))

     # Relax atol as our reduction dim becomes larger (more rounding error)
@@ -225,7 +225,7 @@ def test_gptq_marlin_gemm(
     opcheck(
         torch.ops._C.gptq_marlin_gemm,
         (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
-         workspace.scratch, quant_type, a_input.shape[0], b_weight.shape[1],
+         workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1],
          a_input.shape[1], is_k_full, False, use_fp32_reduce),
         test_utils=DEFAULT_OPCHECK_TEST_UTILS)

@@ -254,6 +254,16 @@ def test_gptq_marlin_gemm(
     assert max_diff < 0.04


+# TODO: find better way to test this?
+@torch.compile(fullgraph=True)
+def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta,
+                          marlin_24_s, scratch, quant_type, size_m, size_n,
+                          size_k):
+    return ops.gptq_marlin_24_gemm(a_input, marlin_24_q_w_comp, marlin_24_meta,
+                                   marlin_24_s, scratch, quant_type, size_m,
+                                   size_n, size_k)
+
+
 @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                     reason="Marlin is not supported on this GPU type.")
 @pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@@ -282,11 +292,11 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,

     opcheck(torch.ops._C.gptq_marlin_24_gemm,
             (a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
-             workspace_24.scratch, quant_type, a_input.shape[0],
+             workspace_24.scratch, quant_type.id, a_input.shape[0],
              b_weight.shape[1], a_input.shape[1]),
             test_utils=DEFAULT_OPCHECK_TEST_UTILS)

-    output = ops.gptq_marlin_24_gemm(
+    output = marlin_24_gemm_tester(
         a_input,
         marlin_24_q_w_comp,
         marlin_24_meta,
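Editor's note (not part of the diff): the tests above now pass quant_type.id into opcheck and route the real call through a torch.compile(fullgraph=True) wrapper. For reference, a minimal, self-contained example of what opcheck exercises on a custom op with a registered fake (the opcheck_demo namespace and scaled_fill op are made up for illustration; assumes PyTorch 2.4+):

    import torch
    from torch.library import custom_op, register_fake, opcheck

    @custom_op("opcheck_demo::scaled_fill", mutates_args=())
    def scaled_fill(x: torch.Tensor, value: float) -> torch.Tensor:
        # Real implementation: a tensor shaped like x, filled with `value`.
        return torch.full_like(x, value)

    @register_fake("opcheck_demo::scaled_fill")
    def _scaled_fill_fake(x, value):
        # Fake (meta) implementation: shape/dtype/device only, no data.
        return torch.empty_like(x)

    # opcheck cross-checks the schema, the fake impl, and dispatch behaviour,
    # which is what the kernel tests above rely on.
    opcheck(torch.ops.opcheck_demo.scaled_fill.default, (torch.randn(4, 4), 2.0))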
@@ -240,8 +240,8 @@ def test_fused_marlin_moe(
                             requires_grad=False)
     opcheck(torch.ops._moe_C.marlin_gemm_moe,
             (a, qweight1, sorted_token_ids, topk_weights, topk_ids,
-             scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m,
-             2 * n, k, True, e, topk, block_size_m, True, False))
+             scales1, zp, g_idx1, sort_indices1, workspace, quant_type.id,
+             m, 2 * n, k, True, e, topk, block_size_m, True, False))


 @pytest.mark.skip("This test is here for the sake of debugging, "
@@ -32,5 +32,5 @@ def test_scalar_type_min_max(type_tuple):
     max = torch.iinfo(torch_type).max

     print(t, min, max, t.min(), t.max())
-    assert min == t.min()
-    assert max == t.max()
+    assert min == t.min(), f"min: {min} != {t.min()}"
+    assert max == t.max(), f"max: {max} != {t.max()}"
@@ -16,7 +16,6 @@ Typical output looks like this:
       2.6 weighted s to build ...torch_bindings.cpp.o (31.5 s elapsed time)
       3.2 weighted s to build ...torch_bindings.cpp.o (38.5 s elapsed time)
    Longest build steps for .so (linking):
-       0.1 weighted s to build _core_C.abi3.so (0.7 s elapsed time)
       0.1 weighted s to build _moe_C.abi3.so (1.0 s elapsed time)
       0.5 weighted s to build ...flash_attn_c.abi3.so (1.1 s elapsed time)
       6.2 weighted s to build _C.abi3.so (6.2 s elapsed time)
@@ -1,278 +0,0 @@
-import importlib.util
-from enum import Enum
-from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
-
-import torch
-
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-core_C_available = importlib.util.find_spec('._core_C', 'vllm') is not None
-
-
-# Mirrors enum in `core/scalar_type.hpp`
-class NanRepr(Enum):
-    NONE = 0  # nans are not supported
-    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
-    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s
-
-
-if TYPE_CHECKING or not core_C_available:
-    # On platforms were we cannot use/build the C++ core extension (i.e. namely
-    # neuron and tpu), we define the mock ScalarType class here that partially
-    # mimics the C++ ScalarType class.
-    #
-    # We also use this provide type signatures to the Python LSP for the methods
-    # in the C++ ScalarType class. So these type signatures should be kept
-    # in sync with csrc/core/scalar_type.hpp
-
-    from dataclasses import dataclass
-
-    @dataclass(frozen=True)
-    class ScalarType:
-        """
-        ScalarType can represent a wide range of floating point and integer
-        types, in particular it can be used to represent sub-byte data types
-        (something that torch.dtype currently does not support). It is also
-        capable of representing types with a bias, i.e.:
-          `stored_value = value + bias`,
-        this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
-        of 8). The implementation for this class can be found in
-        csrc/core/scalar_type.hpp, these type signatures should be kept in sync
-        with that file.
-        """
-
-        exponent: int
-        """
-        Number of bits in the exponent if this is a floating point type
-        (zero if this an integer type)
-        """
-
-        mantissa: int
-        """
-        Number of bits in the mantissa if this is a floating point type,
-        or the number bits representing an integer excluding the sign bit if
-        this an integer type.
-        """
-
-        bias: int
-        """
-        bias used to encode the values in this scalar type
-        (value = stored_value - bias, default 0) for example if we store the
-        type as an unsigned integer with a bias of 128 then the value 0 will be
-        stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
-        """
-
-        signed: bool
-        "If the type is signed (i.e. has a sign bit)"
-
-        _finite_values_only: bool = False
-        """
-        Private: if NANs are supported, used `has_infs()` instead.
-        """
-
-        nan_repr: int = NanRepr.IEEE_754.value
-        """
-        How NaNs are represent in this scalar type, returns NanRepr value.
-        (not applicable for integer types)
-        """
-
-        @property
-        def size_bits(self):
-            return self.exponent + self.mantissa + int(self.signed)
-
-        def min(self) -> Union[int, float]:
-            """
-            Min representable value for this scalar type.
-            (accounting for bias if there is one)
-            """
-            raise NotImplementedError
-
-        def max(self) -> Union[int, float]:
-            """
-            Max representable value for this scalar type.
-            (accounting for bias if there is one)
-            """
-            raise NotImplementedError
-
-        def is_signed(self) -> bool:
-            """
-            If the type is signed (i.e. has a sign bit), same as `signed`
-            added for consistency with:
-            https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
-            """
-            ...
-
-        def is_floating_point(self) -> bool:
-            "If the type is a floating point type"
-            return self.exponent != 0
-
-        def is_integer(self) -> bool:
-            "If the type is an integer type"
-            return self.exponent == 0
-
-        def has_bias(self) -> bool:
-            "If the type has a non-zero bias"
-            return self.bias != 0
-
-        def has_infs(self) -> bool:
-            "If the type is floating point and supports infinity"
-            return not self._finite_values_only
-
-        def has_nans(self) -> bool:
-            return self.nan_repr != NanRepr.NONE.value
-
-        def is_ieee_754(self) -> bool:
-            """
-            If the type is a floating point type that follows IEEE 754
-            conventions
-            """
-            return self.nan_repr == NanRepr.IEEE_754.value and \
-                not self._finite_values_only
-
-        def __str__(self) -> str:
-            raise NotImplementedError
-
-        def __repr__(self) -> str:
-            raise NotImplementedError
-
-        # __len__ needs to be defined (and has to throw TypeError) for pytorch's
-        # opcheck to work.
-        def __len__(self) -> int:
-            raise TypeError
-
-        #
-        # Convenience Constructors
-        #
-
-        @classmethod
-        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
-            "Create a signed integer scalar type (size_bits includes sign-bit)."
-            return cls(size_bits - 1, size_bits, bias if bias else 0, True)
-
-        @classmethod
-        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
-            """Create a unsigned integer scalar type."""
-            return cls(size_bits, size_bits, bias if bias else 0, False)
-
-        @classmethod
-        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
-            """
-            Create a standard floating point type
-            (i.e. follows IEEE 754 conventions).
-            """
-            return cls(exponent, mantissa, 0, True)
-
-        @classmethod
-        def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
-                   nan_repr: int) -> 'ScalarType':
-            """
-            Create a non-standard floating point type
-            (i.e. does not follow IEEE 754 conventions).
-            """
-            return cls(exponent, mantissa, 0, True, finite_values_only,
-                       nan_repr)
-
-elif core_C_available:
-    try:
-        import vllm._core_C  # noqa: F401
-    except ImportError as e:
-        logger.warning("Failed to import from vllm._core_C with %r", e)
-
-    ScalarType = torch.classes._core_C.ScalarType
-
-    if (hasattr(torch, "_library")
-            and hasattr(torch._library, "register_fake_class")):
-        # Needed for dynamo support of ScalarType.
-        @torch._library.register_fake_class("_core_C::ScalarType")
-        class FakeScalarType:
-
-            def __init__(self, scalar_type):
-                self.ScalarType = scalar_type
-
-            def bias_getter(self) -> int:
-                return self.ScalarType.bias
-
-            def exponent_getter(self) -> int:
-                return self.ScalarType.exponent
-
-            def mantissa_getter(self) -> int:
-                return self.ScalarType.mantissa
-
-            def signed_getter(self) -> bool:
-                return self.ScalarType.signed
-
-            def size_bits_getter(self) -> int:
-                return self.ScalarType.size_bits
-
-            @property
-            def size_bits(self) -> int:
-                return self.ScalarType.size_bits
-
-            def min(self) -> Union[int, float]:
-                return self.ScalarType.min()
-
-            def max(self) -> Union[int, float]:
-                return self.ScalarType.max()
-
-            def is_signed(self) -> bool:
-                return self.ScalarType.is_signed()
-
-            def is_floating_point(self) -> bool:
-                return self.ScalarType.is_floating_point()
-
-            def is_integer(self) -> bool:
-                return self.ScalarType.is_integer()
-
-            def has_bias(self) -> bool:
-                return self.ScalarType.has_bias()
-
-            def has_infs(self) -> bool:
-                return self.ScalarType.has_infs()
-
-            def has_nans(self) -> bool:
-                return self.ScalarType.has_nans()
-
-            def is_ieee_754(self) -> bool:
-                return self.ScalarType.is_ieee_754()
-
-            def __str__(self) -> str:
-                return self.ScalarType.__str__()
-
-            def __repr__(self) -> str:
-                return self.ScalarType.__repr__()
-
-            def __len__(self) -> int:
-                return self.ScalarType.__len__()
-
-            def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
-                return torch.classes._core_C.ScalarType.__obj_flatten__(
-                    self.ScalarType)
-
-            @classmethod
-            def __obj_unflatten__(
-                    cls, flat_type: Tuple[Tuple[str, Any],
-                                          ...]) -> 'ScalarType':
-                return cls(
-                    torch.classes._core_C.ScalarType.__obj_unflatten__(
-                        flat_type))
-
-            @classmethod
-            def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
-                return ScalarType.int_(size_bits, bias)
-
-            @classmethod
-            def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
-                return ScalarType.uint(size_bits, bias)
-
-            @classmethod
-            def float_IEEE754(cls, exponent: int,
-                              mantissa: int) -> 'ScalarType':
-                return ScalarType.float_IEEE754(exponent, mantissa)
-
-            @classmethod
-            def float_(cls, exponent: int, mantissa: int,
-                       finite_values_only: bool,
-                       nan_repr: int) -> 'ScalarType':
-                return ScalarType.float_(exponent, mantissa,
-                                         finite_values_only, nan_repr)
@@ -6,9 +6,9 @@ import torch
 import torch.library

 import vllm.envs as envs
-from vllm._core_ext import ScalarType
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.scalar_type import ScalarType

 logger = init_logger(__name__)

@@ -306,7 +306,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                         workspace: torch.Tensor, b_q_type: ScalarType,
                         size_m: int, size_n: int, size_k: int) -> torch.Tensor:
     return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
-                                            workspace, b_q_type, size_m,
+                                            workspace, b_q_type.id, size_m,
                                             size_n, size_k)


@@ -316,8 +316,9 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
     def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                   b_meta: torch.Tensor, b_scales: torch.Tensor,
                                   workspace: torch.Tensor,
-                                  b_q_type: ScalarType, size_m: int,
-                                  size_n: int, size_k: int) -> torch.Tensor:
+                                  b_q_type: ScalarType, size_m: torch.SymInt,
+                                  size_n: torch.SymInt,
+                                  size_k: torch.SymInt) -> torch.Tensor:
         return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

     @register_fake("_C::gptq_marlin_gemm")
@@ -329,17 +330,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
                                perm: torch.Tensor,
                                workspace: torch.Tensor,
                                b_q_type: ScalarType,
-                               size_m: int,
-                               size_n: int,
-                               size_k: int,
+                               size_m: torch.SymInt,
+                               size_n: torch.SymInt,
+                               size_k: torch.SymInt,
                                is_k_full: bool,
                                has_zp: bool = False,
                                use_fp32_reduce: bool = False) -> torch.Tensor:
         return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

     @register_fake("_C::ggml_dequantize")
-    def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
-                              n: int) -> torch.Tensor:
+    def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int,
+                              m: torch.SymInt,
+                              n: torch.SymInt) -> torch.Tensor:
         return torch.empty((m, n), dtype=torch.float16, device=W.device)

     @register_fake("_C::ggml_mul_mat_vec_a8")
@@ -347,7 +349,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
             W: torch.Tensor,
             X: torch.Tensor,
             quant_type: int,
-            row: int,
+            row: torch.SymInt,
     ) -> torch.Tensor:
         return torch.empty((1, row), dtype=torch.float16, device=W.device)

@@ -356,7 +358,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
             W: torch.Tensor,
             X: torch.Tensor,
             quant_type: int,
-            row: int,
+            row: torch.SymInt,
     ) -> torch.Tensor:
         batch = X.size(0)
         return torch.empty((batch, row), dtype=torch.float16, device=W.device)
@@ -365,8 +367,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
     def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                               s_tok: torch.Tensor, s_ch: torch.Tensor,
                               s_group: torch.Tensor, workspace: torch.Tensor,
-                              size_m: int, size_n: int,
-                              size_k: int) -> torch.Tensor:
+                              size_m: torch.SymInt, size_n: torch.SymInt,
+                              size_k: torch.SymInt) -> torch.Tensor:
         return torch.empty((size_m, size_n),
                            dtype=torch.float16,
                            device=a.device)
@@ -374,16 +376,16 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
     @register_fake("_C::marlin_gemm")
     def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                           b_scales: torch.Tensor, workspace: torch.Tensor,
-                          size_m: int, size_n: int,
-                          size_k: int) -> torch.Tensor:
+                          size_m: torch.SymInt, size_n: torch.SymInt,
+                          size_k: torch.SymInt) -> torch.Tensor:
         return torch.empty((size_m, size_n),
                            dtype=torch.float16,
                            device=a.device)

     @register_fake("_C::awq_dequantize")
     def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
-                             zeros: torch.Tensor, split_k_iters: int, thx: int,
-                             thy: int) -> torch.Tensor:
+                             zeros: torch.Tensor, split_k_iters: torch.SymInt,
+                             thx: int, thy: int) -> torch.Tensor:
         in_c = qweight.size(0)
         qout_c = qweight.size(1)
         out_c = qout_c * 8
@@ -394,7 +396,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
     @register_fake("_C::awq_gemm")
     def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
                        qzeros: torch.Tensor, scales: torch.Tensor,
-                       split_k_iters: int) -> torch.Tensor:
+                       split_k_iters: torch.SymInt) -> torch.Tensor:
         num_in_feats = input.size(0)
         return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8),
                            dtype=input.dtype,
@@ -429,8 +431,9 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
     @register_fake("_C::fp8_marlin_gemm")
     def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                               b_scales: torch.Tensor, workspace: torch.Tensor,
-                              num_bits: int, size_m: int, size_n: int,
-                              size_k: int) -> torch.Tensor:
+                              num_bits: int, size_m: torch.SymInt,
+                              size_n: torch.SymInt,
+                              size_k: torch.SymInt) -> torch.Tensor:
         return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)

     @register_fake("_C::machete_gemm")
@@ -457,40 +460,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
         return torch.empty_like(b_q_weight,
                                 memory_format=torch.contiguous_format)

-    @register_fake("_C::causal_conv1d_fwd")
-    def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
-                               bias_: Optional[torch.Tensor],
-                               conv_states: Optional[torch.Tensor],
-                               cu_seq_len: Optional[torch.Tensor],
-                               cache_indices: Optional[torch.Tensor],
-                               has_initial_state: Optional[torch.Tensor],
-                               silu_activation: bool, pad_slot_id: int):
-        return None
-
-    @register_fake("_C::causal_conv1d_update")
-    def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
-                                  weight: torch.Tensor,
-                                  bias_: Optional[torch.Tensor],
-                                  silu_activation: bool,
-                                  cache_seqlens: Optional[torch.Tensor],
-                                  conv_state_indices: Optional[torch.Tensor],
-                                  pad_slot_id: int) -> None:
-        return None
-
-    @register_fake("_C::selective_scan_fwd")
-    def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
-                                A: torch.Tensor, B: torch.Tensor,
-                                C: torch.Tensor, D_: Optional[torch.Tensor],
-                                z_: Optional[torch.Tensor],
-                                delta_bias_: Optional[torch.Tensor],
-                                delta_softplus: bool,
-                                cu_seq_len: Optional[torch.Tensor],
-                                cache_indices: Optional[torch.Tensor],
-                                has_initial_state: Optional[torch.Tensor],
-                                ssm_states: Optional[torch.Tensor],
-                                pad_slot_id: int) -> None:
-        return None
-

 # cutlass
 def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
@@ -611,7 +580,7 @@ def gptq_marlin_gemm(a: torch.Tensor,
                      has_zp: bool = False,
                      use_fp32_reduce: bool = False) -> torch.Tensor:
     return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
-                                         g_idx, perm, workspace, b_q_type,
+                                         g_idx, perm, workspace, b_q_type.id,
                                          size_m, size_n, size_k, is_k_full,
                                          has_zp, use_fp32_reduce)

@@ -627,7 +596,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,

 # machete
 def machete_supported_schedules(b_type: ScalarType) -> List[str]:
-    return torch.ops._C.machete_supported_schedules(b_type)
+    return torch.ops._C.machete_supported_schedules(b_type.id)


 def machete_gemm(
@@ -642,13 +611,13 @@ def machete_gemm(
         beta: Optional[float] = None,
         schedule: Optional[str] = None,
 ) -> torch.Tensor:
-    return torch.ops._C.machete_gemm(a, b_q, b_type, b_scales, b_zeros,
+    return torch.ops._C.machete_gemm(a, b_q, b_type.id, b_scales, b_zeros,
                                      b_group_size, c, alpha, beta, schedule)


 def machete_prepack_B(b_q_weight: torch.Tensor,
                       b_type: ScalarType) -> torch.Tensor:
-    return torch.ops._C.machete_prepack_B(b_q_weight, b_type)
+    return torch.ops._C.machete_prepack_B(b_q_weight, b_type.id)


 if hasattr(torch.ops._C, "permute_cols"):
@@ -862,10 +831,10 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
                              topk_ids: torch.Tensor, b_scales: torch.Tensor,
                              b_zero_points: torch.Tensor, g_idx: torch.Tensor,
                              perm: torch.Tensor, workspace: torch.Tensor,
-                             b_q_type: ScalarType, size_m: int, size_n: int,
-                             size_k: int, is_k_full: bool, num_experts: int,
-                             topk: int, moe_block_size: int,
-                             replicate_input: bool,
+                             b_q_type: ScalarType, size_m: torch.SymInt,
+                             size_n: torch.SymInt, size_k: torch.SymInt,
+                             is_k_full: bool, num_experts: int, topk: int,
+                             moe_block_size: int, replicate_input: bool,
                              apply_weights: bool) -> torch.Tensor:
         return torch.empty((size_m, topk, size_n),
                            dtype=a.dtype,
@@ -116,7 +116,7 @@ def single_marlin_moe(

     intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
         hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
-        w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K,
+        w_zeros, g_idx, sort_indices, workspace, scalar_type.id, M, N, K,
         is_k_full, E, topk, block_size_m, True, False)

     return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
@@ -272,7 +272,7 @@ def fused_marlin_moe(
         g_idx1,
         sort_indices1,
         workspace,
-        scalar_type1,
+        scalar_type1.id,
         M,
         2 * N,
         K,
@@ -297,7 +297,7 @@ def fused_marlin_moe(
         g_idx2,
         sort_indices2,
         workspace,
-        scalar_type2,
+        scalar_type2.id,
         M,
         K,
         N,
@@ -1,4 +1,298 @@
-from ._core_ext import NanRepr, ScalarType
+import functools
+import struct
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional, Union
+
+
+# Mirrors enum in `core/scalar_type.hpp`
+class NanRepr(Enum):
+    NONE = 0  # nans are not supported
+    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
+    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s
+
+
+# This ScalarType class is a parallel implementation of the C++ ScalarType
+# class found in csrc/core/scalar_type.hpp. These two classes should be kept
+# in sync until the inductor fully supports custom C++ classes.
+@dataclass(frozen=True)
+class ScalarType:
+    """
+    ScalarType can represent a wide range of floating point and integer
+    types, in particular it can be used to represent sub-byte data types
+    (something that torch.dtype currently does not support). It is also
+    capable of representing types with a bias, i.e.:
+      `stored_value = value + bias`,
+    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
+    of 8). The implementation for this class can be found in
+    csrc/core/scalar_type.hpp, these type signatures should be kept in sync
+    with that file.
+    """
+
+    exponent: int
+    """
+    Number of bits in the exponent if this is a floating point type
+    (zero if this an integer type)
+    """
+
+    mantissa: int
+    """
+    Number of bits in the mantissa if this is a floating point type,
+    or the number bits representing an integer excluding the sign bit if
+    this an integer type.
+    """
+
+    signed: bool
+    "If the type is signed (i.e. has a sign bit)"
+
+    bias: int
+    """
+    bias used to encode the values in this scalar type
+    (value = stored_value - bias, default 0) for example if we store the
+    type as an unsigned integer with a bias of 128 then the value 0 will be
+    stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
+    """
+
+    _finite_values_only: bool = False
+    """
+    Private: if infs are supported, used `has_infs()` instead.
+    """
+
+    nan_repr: NanRepr = NanRepr.IEEE_754
+    """
+    How NaNs are represent in this scalar type, returns NanRepr value.
+    (not applicable for integer types)
+    """
+
+    def _floating_point_max_int(self) -> int:
+        assert (
+            self.mantissa <= 52 and self.exponent <= 11
+        ), f"Cannot represent max/min as a double for type {self.__str__()}"
+
+        max_mantissa = (1 << self.mantissa) - 1
+        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
+            max_mantissa = max_mantissa - 1
+
+        max_exponent = (1 << self.exponent) - 2
+        if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
+                or self.nan_repr == NanRepr.NONE):
+            assert (
+                self.exponent < 11
+            ), f"Cannot represent max/min as a double for type {self.__str__()}"
+            max_exponent = max_exponent + 1
+
+        # adjust the exponent to match that of a double
+        # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
+        # e is the exponent bits), there is some precedent for non-standard
+        # biases, example `float8_e4m3b11fnuz` here:
+        # https://github.com/jax-ml/ml_dtypes but to avoid premature over
+        # complication we are just assuming the standard exponent bias until
+        # there is a need to support non-standard biases
+        exponent_bias = (1 << (self.exponent - 1)) - 1
+        exponent_bias_double = (1 << 10) - 1  # double e = 11
+
+        max_exponent_double = (max_exponent - exponent_bias +
+                               exponent_bias_double)
+
+        # shift the mantissa and exponent into the proper positions for an
+        # IEEE double and bitwise-or them together.
+        return (max_mantissa <<
+                (52 - self.mantissa)) | (max_exponent_double << 52)
+
+    def _floating_point_max(self) -> float:
+        double_raw = self._floating_point_max_int()
+        return struct.unpack('!d', struct.pack('!Q', double_raw))[0]
+
+    def _raw_max(self) -> Union[int, float]:
+        if self.is_floating_point():
+            return self._floating_point_max()
+        else:
+            assert (self.size_bits < 64 or self.size_bits == 64
+                    and self.is_signed()), "Cannot represent max as an int"
+            return (1 << self.mantissa) - 1
+
+    def _raw_min(self) -> Union[int, float]:
+        if self.is_floating_point():
+            assert self.is_signed(
+            ), "We currently assume all floating point types are signed"
+            sign_bit_double = 1 << 63
+
+            max_raw = self._floating_point_max_int()
+            min_raw = max_raw | sign_bit_double
+            return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
+        else:
+            assert (not self.is_signed() or
+                    self.size_bits <= 64), "Cannot represent min as a int64_t"
+
+            if self.is_signed():
+                return -(1 << (self.size_bits - 1))
+            else:
+                return 0
+
+    @functools.cached_property
+    def id(self) -> int:
+        """
+        Convert the ScalarType to an int which can be passed to pytorch custom
+        ops. This layout of the int must be kept in sync with the C++
+        ScalarType's from_id method.
+        """
+        val = 0
+        offset = 0
+
+        def or_and_advance(member, bit_width):
+            nonlocal val
+            nonlocal offset
+            bit_mask = (1 << bit_width) - 1
+            val = val | (int(member) & bit_mask) << offset
+            offset = offset + bit_width
+
+        or_and_advance(self.exponent, 8)
+        or_and_advance(self.mantissa, 8)
+        or_and_advance(self.signed, 1)
+        or_and_advance(self.bias, 32)
+        or_and_advance(self._finite_values_only, 1)
+        or_and_advance(self.nan_repr.value, 8)
+
+        assert offset <= 64, \
+            f"ScalarType fields too big {offset} to fit into an int64"
+
+        return val
+
+    @property
+    def size_bits(self) -> int:
+        return self.exponent + self.mantissa + int(self.signed)
+
+    def min(self) -> Union[int, float]:
+        """
+        Min representable value for this scalar type.
+        (accounting for bias if there is one)
+        """
+        return self._raw_min() - self.bias
+
+    def max(self) -> Union[int, float]:
+        """
+        Max representable value for this scalar type.
+        (accounting for bias if there is one)
+        """
+        return self._raw_max() - self.bias
+
+    def is_signed(self) -> bool:
+        """
+        If the type is signed (i.e. has a sign bit), same as `signed`
+        added for consistency with:
+        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
+        """
+        return self.signed
+
+    def is_floating_point(self) -> bool:
+        "If the type is a floating point type"
+        return self.exponent != 0
+
+    def is_integer(self) -> bool:
+        "If the type is an integer type"
+        return self.exponent == 0
+
+    def has_bias(self) -> bool:
+        "If the type has a non-zero bias"
+        return self.bias != 0
+
+    def has_infs(self) -> bool:
+        "If the type is floating point and supports infinity"
+        return not self._finite_values_only
+
+    def has_nans(self) -> bool:
+        return self.nan_repr != NanRepr.NONE.value
+
+    def is_ieee_754(self) -> bool:
+        """
+        If the type is a floating point type that follows IEEE 754
+        conventions
+        """
+        return self.nan_repr == NanRepr.IEEE_754.value and \
+            not self._finite_values_only
+
+    def __str__(self) -> str:
+        """
+        naming generally follows: https://github.com/jax-ml/ml_dtypes
+        for floating point types (leading f) the scheme is:
+        `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
+        flags:
+          - no-flags: means it follows IEEE 754 conventions
+          - f: means finite values only (no infinities)
+          - n: means nans are supported (non-standard encoding)
+        for integer types the scheme is:
+        `[u]int<size_bits>[b<bias>]`
+          - if bias is not present it means its zero
+        """
+        if self.is_floating_point():
+            ret = "float" + str(self.size_bits) + "_e" + str(
+                self.exponent) + "m" + str(self.mantissa)
+
+            if not self.is_ieee_754():
+                if self._finite_values_only:
+                    ret = ret + "f"
+                if self.nan_repr != NanRepr.NONE:
+                    ret = ret + "n"
+
+            return ret
+        else:
+            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
+            if self.has_bias():
+                ret = ret + "b" + str(self.bias)
+            return ret
+
+    def __repr__(self) -> str:
+        return "ScalarType." + self.__str__()
+
+    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
+    # opcheck to work.
+    def __len__(self) -> int:
+        raise TypeError
+
+    #
+    # Convenience Constructors
+    #
+
+    @classmethod
+    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+        "Create a signed integer scalar type (size_bits includes sign-bit)."
+        ret = cls(0, size_bits - 1, True, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+        """Create a unsigned integer scalar type."""
+        ret = cls(0, size_bits, False, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
+        """
+        Create a standard floating point type
+        (i.e. follows IEEE 754 conventions).
+        """
+        assert (mantissa > 0 and exponent > 0)
+        ret = cls(exponent, mantissa, True, 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
+               nan_repr: NanRepr) -> 'ScalarType':
+        """
+        Create a non-standard floating point type
+        (i.e. does not follow IEEE 754 conventions).
+        """
+        assert (mantissa > 0 and exponent > 0)
+        assert (nan_repr != NanRepr.IEEE_754), (
+            "use `float_IEEE754` constructor for floating point types that "
+            "follow IEEE 754 conventions")
+        ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+
 # naming generally follows: https://github.com/jax-ml/ml_dtypes
 # for floating point types (leading f) the scheme is:
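A short, hedged usage sketch of the pure-Python class added above (values are derived by reading the code, not measured): the convenience constructors build frozen instances, `size_bits`/`min()`/`max()` follow the definitions shown, and `id` bit-packs exponent, mantissa, sign, bias, finite-only flag and NaN representation into a single int64 that the custom ops accept in place of the object.

from vllm.scalar_type import NanRepr, ScalarType

fp8 = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
print(str(fp8), fp8.size_bits)       # "float8_e4m3fn", 8 (4 + 3 + sign bit)
print(fp8.id)                        # packed int64; layout matches C++ from_id

uint4b8 = ScalarType.uint(4, 8)      # 4-bit unsigned with a GPTQ-style bias of 8
print(uint4b8.min(), uint4b8.max())  # -8, 7 after subtracting the bias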
@@ -17,14 +311,13 @@ class scalar_types:
     uint4 = ScalarType.uint(4, None)
     int8 = ScalarType.int_(8, None)
     uint8 = ScalarType.uint(8, None)
-    float8_e4m3fn = ScalarType.float_(4, 3, True,
-                                      NanRepr.EXTD_RANGE_MAX_MIN.value)
+    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
     float8_e5m2 = ScalarType.float_IEEE754(5, 2)
     float16_e8m7 = ScalarType.float_IEEE754(8, 7)
     float16_e5m10 = ScalarType.float_IEEE754(5, 10)

     # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
-    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value)
+    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)

     # "gptq" types
     uint2b2 = ScalarType.uint(2, 2)
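Note the constructor convention that changes in this hunk: `ScalarType.float_` now takes the `NanRepr` enum member itself rather than its `.value`. A one-line sketch of the old versus new call, using names from the hunk above:

from vllm.scalar_type import NanRepr, ScalarType

# before: ScalarType.float_(3, 2, True, NanRepr.NONE.value)
float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)  # after this change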