[Bugfix] Fix support for dimension like integers and ScalarType (#9299)
parent 0f41fbe5a3
commit eca2c5f7c0
@ -230,14 +230,12 @@ steps:
commands:
- pytest -v -s compile/test_basic_correctness.py

# TODO: re-write in comparison tests, and fix symbolic shape
# for quantization ops.
# - label: "PyTorch Fullgraph Test" # 18min
# source_file_dependencies:
# - vllm/
# - tests/compile
# commands:
# - pytest -v -s compile/test_full_graph.py
- label: "PyTorch Fullgraph Test" # 18min
source_file_dependencies:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_full_graph.py

- label: Kernels Test %N # 1h each
mirror_hardwares: [amd]
@ -83,24 +83,6 @@ endif()
#
find_package(Torch REQUIRED)

#
message(STATUS "Enabling core extension.")

# Define _core_C extension
# built for (almost) every target platform, (excludes TPU and Neuron)

set(VLLM_EXT_SRC
"csrc/core/torch_bindings.cpp")

define_gpu_extension_target(
_core_C
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
USE_SABI 3
WITH_SOABI)

#
# Forward the non-CUDA device extensions to external CMake scripts.
#
@ -1,6 +1,7 @@
#pragma once

#include <torch/custom_class.h>
// For TORCH_CHECK
#include <torch/library.h>

namespace vllm {

@ -9,12 +10,7 @@ namespace vllm {
// in particular it can be used to represent sub-byte data types (something
// that torch.dtype currently does not support).
//
// ScalarTypeTorch is a subclass of ScalarType that is compatible with
// TORCH_LIBRARY, making it accessible from Python as well meaning this class
// can be used as a argument for custom operators, helping to simplify these
// interfaces.
//
// The type definitions on the Python side can be found in: vllm/_core_ext.pyi
// The type definitions on the Python side can be found in: vllm/scalar_type.py
// these type definitions should be kept up to date with any Python API changes
// here.
//
@ -308,204 +304,7 @@ class ScalarType {
|
||||
}
|
||||
};
|
||||
|
||||
// Create a TORCH_LIBRARY compatible version of ScalarType (i.e. inherit from
|
||||
// torch::CustomClassHolder), we use multiple inheritance here since we cannot
|
||||
// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
|
||||
// constructor at the same time (torch::CustomClassHolder does not have a
|
||||
// constexpr destructor)
|
||||
// See also:
|
||||
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
|
||||
class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
|
||||
public:
|
||||
ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
|
||||
bool _signed)
|
||||
: ScalarType(exponent, mantissa, bias, _signed){};
|
||||
|
||||
ScalarTypeTorch(ScalarType type) : ScalarType(type){};
|
||||
|
||||
using Base = ScalarType;
|
||||
using Self = ScalarTypeTorch;
|
||||
using SelfPtr = c10::intrusive_ptr<Self>;
|
||||
|
||||
static void check_size_bits(int64_t size_bits, bool signed_) {
|
||||
TORCH_CHECK(
|
||||
size_bits <=
|
||||
std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
|
||||
"size_bits bit width is too large to be represented");
|
||||
}
|
||||
|
||||
static void check_bias(int64_t bias) {
|
||||
using Bias = decltype(std::declval<Self>().bias);
|
||||
TORCH_CHECK(bias <= std::numeric_limits<Bias>::max() &&
|
||||
bias >= std::numeric_limits<Bias>::min(),
|
||||
"bias too large or small to be represented");
|
||||
}
|
||||
|
||||
static void check_exponent(int64_t exponent) {
|
||||
TORCH_CHECK(
|
||||
exponent <=
|
||||
std::numeric_limits<decltype(std::declval<Self>().exponent)>::max(),
|
||||
"exponent bit width is too large to be represented");
|
||||
}
|
||||
|
||||
static void check_mantissa(int64_t mantissa) {
|
||||
TORCH_CHECK(
|
||||
mantissa <=
|
||||
std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
|
||||
"mantissa bit width is too large to be represented");
|
||||
}
|
||||
|
||||
static SelfPtr int_(int64_t size_bits, c10::optional<int64_t> bias) {
|
||||
check_size_bits(size_bits, true);
|
||||
check_bias(bias.value_or(0));
|
||||
return c10::make_intrusive<Self>(
|
||||
ScalarType::int_(size_bits, bias.value_or(0)));
|
||||
}
|
||||
|
||||
static SelfPtr uint(int64_t size_bits, c10::optional<int64_t> bias) {
|
||||
check_size_bits(size_bits, true);
|
||||
check_bias(bias.value_or(0));
|
||||
return c10::make_intrusive<Self>(
|
||||
ScalarType::uint(size_bits, bias.value_or(0)));
|
||||
}
|
||||
|
||||
static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) {
|
||||
check_mantissa(mantissa);
|
||||
check_exponent(exponent);
|
||||
return c10::make_intrusive<Self>(
|
||||
ScalarType::float_IEEE754(exponent, mantissa));
|
||||
}
|
||||
|
||||
static SelfPtr float_(int64_t exponent, int64_t mantissa,
|
||||
bool finite_values_only, int64_t nan_repr) {
|
||||
check_mantissa(mantissa);
|
||||
check_exponent(exponent);
|
||||
return c10::make_intrusive<Self>(ScalarType::float_(
|
||||
exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
|
||||
}
|
||||
|
||||
// This needs to be implemented and throw a TypeError in order for
|
||||
// PyTorch's opcheck to work on ops that use ScalarTypes.
|
||||
int64_t len() const {
|
||||
throw c10::TypeError({__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
|
||||
"__len__ not implemented");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Serialize a ScalarType into a tuple of pairs. Where each pair
|
||||
// is a (fieldname, value).
|
||||
// For simplicity, we are just going to convert to a ScalarTypeId.
|
||||
std::tuple<std::tuple<std::string, int64_t>> obj_flatten() const {
|
||||
return {{"ScalarType", id()}};
|
||||
}
|
||||
|
||||
// Deserialize a scalar type that has been serialized by obj_flatten,
|
||||
// ostensibly from a tuple of (member name, value) pairs, but in reality
|
||||
// just a ScalarTypeId.
|
||||
static SelfPtr obj_unflatten(
|
||||
std::tuple<std::tuple<std::string, int64_t>> const& flat_type) {
|
||||
return c10::make_intrusive<Self>(
|
||||
from_id(std::get<1>(std::get<0>(flat_type))));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void bind_readonly_property(torch::class_<Self>& cls,
|
||||
std::string const& name, T Base::*field) {
|
||||
auto getter_func_helper = [field = std::move(field)](SelfPtr const& self) {
|
||||
if constexpr (std::is_member_function_pointer_v<decltype(field)>) {
|
||||
return (self.get()->*field)();
|
||||
} else {
|
||||
return self.get()->*field;
|
||||
}
|
||||
};
|
||||
|
||||
auto getter_func = [field = std::move(field),
|
||||
getter_func_helper = std::move(getter_func_helper)](
|
||||
SelfPtr const& self) {
|
||||
auto val = getter_func_helper(self);
|
||||
// upconvert uint8_t, int32_t etc. to int64_t for python
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
return static_cast<int64_t>(val);
|
||||
} else {
|
||||
return val;
|
||||
}
|
||||
};
|
||||
|
||||
cls.def_property(name, getter_func);
|
||||
}
|
||||
|
||||
template <typename MemberFunc, typename Cls>
|
||||
static void bind_function(torch::class_<Self>& cls, const std::string& name,
|
||||
MemberFunc Cls::*member) {
|
||||
cls.def(name, [member = std::move(member)](SelfPtr const& self) {
|
||||
return (self.get()->*member)();
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
static void bind_function(torch::class_<Self>& cls, const std::string& name,
|
||||
Func func) {
|
||||
cls.def(name, func);
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
static void bind_static_function(torch::class_<Self>& cls,
|
||||
const std::string& name, Func func) {
|
||||
cls.def_static(name, func);
|
||||
}
|
||||
|
||||
static void bind_class(torch::Library& lib) {
|
||||
auto cls = lib.class_<ScalarTypeTorch>("ScalarType")
|
||||
.def(torch::init<int64_t, int64_t, int64_t, bool>());
|
||||
|
||||
// Bind Properties
|
||||
bind_readonly_property(cls, "mantissa", &Base::mantissa);
|
||||
bind_readonly_property(cls, "exponent", &Base::exponent);
|
||||
bind_readonly_property(cls, "bias", &Base::bias);
|
||||
bind_readonly_property(cls, "signed", &Base::is_signed);
|
||||
bind_readonly_property(cls, "size_bits", &Base::size_bits);
|
||||
|
||||
// Bind member functions
|
||||
bind_function(cls, "is_signed", &Base::is_signed);
|
||||
bind_function(cls, "is_integer", &Base::is_integer);
|
||||
bind_function(cls, "is_floating_point", &Base::is_floating_point);
|
||||
bind_function(cls, "is_ieee_754", &Base::is_ieee_754);
|
||||
bind_function(cls, "has_nans", &Base::has_nans);
|
||||
bind_function(cls, "has_infs", &Base::has_infs);
|
||||
bind_function(cls, "has_bias", &Base::has_bias);
|
||||
|
||||
bind_function(cls, "max", [](SelfPtr const& self) {
|
||||
return std::visit([](auto arg) { return c10::IValue(arg); },
|
||||
self.get()->max());
|
||||
});
|
||||
bind_function(cls, "min", [](SelfPtr const& self) {
|
||||
return std::visit([](auto arg) { return c10::IValue(arg); },
|
||||
self.get()->min());
|
||||
});
|
||||
|
||||
bind_function(cls, "__len__", &ScalarTypeTorch::len);
|
||||
bind_function(cls, "__str__", &Base::str);
|
||||
bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
|
||||
return *self == *other;
|
||||
});
|
||||
bind_function(cls, "__repr__", [](SelfPtr const& self) {
|
||||
return "ScalarType." + self.get()->str();
|
||||
});
|
||||
|
||||
bind_function(cls, "__obj_flatten__", &ScalarTypeTorch::obj_flatten);
|
||||
bind_static_function(cls, "__obj_unflatten__",
|
||||
&ScalarTypeTorch::obj_unflatten);
|
||||
|
||||
// Bind static functions (convenience constructors)
|
||||
bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
|
||||
bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
|
||||
bind_static_function(cls, "float_IEEE754", &ScalarTypeTorch::float_IEEE754);
|
||||
bind_static_function(cls, "float_", &ScalarTypeTorch::float_);
|
||||
}
|
||||
};
|
||||
|
||||
using ScalarTypeId = int64_t;
|
||||
using ScalarTypeTorchPtr = c10::intrusive_ptr<ScalarTypeTorch>;
|
||||
using ScalarTypeId = ScalarType::Id;
|
||||
|
||||
// "rust style" names generally following:
|
||||
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
|
||||
|
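The net effect of the hunk above is that ScalarTypeTorch (the torch::CustomClassHolder wrapper) goes away and callers exchange a plain integer id instead: Python passes ScalarType.id and the C++ side rebuilds the type with ScalarType::from_id. A minimal sketch of that round trip from the Python side, assuming an installed vllm build and using names that appear elsewhere in this diff (the uint(4, 8) example type is illustrative only):

    from vllm.scalar_type import ScalarType

    b_q_type = ScalarType.uint(4, 8)   # 4-bit unsigned with bias 8 (kU4B8-style)
    type_id = b_q_type.id              # packed int64, safe to pass through custom ops
    # the C++ side of an op then does: vllm::ScalarType::from_id(b_q_type_id)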
@ -1,16 +0,0 @@
#include <torch/library.h>

#include "scalar_type.hpp"
#include "registration.h"

// Note the CORE exstension will be built for (almost) all hardware targets so
// new additions must account for this. (currently not built for TPU and Neuron)

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, lib) {
// ScalarType, a custom class for representing data types that supports
// quantized types, declared here so it can be used when creating interfaces
// for custom ops.
vllm::ScalarTypeTorch::bind_class(lib);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
@ -484,21 +484,22 @@ torch::Tensor marlin_gemm_moe(
const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
torch::Tensor& b_zeros, const torch::Tensor& g_idx,
const torch::Tensor& perm, torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n,
vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
int64_t moe_block_size, bool replicate_input, bool apply_weights) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
bool has_zp = b_zeros.size(1) != 0;
if (has_zp) {
TORCH_CHECK(
*b_q_type == vllm::kU4,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type->str());
b_q_type == vllm::kU4,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
} else {
TORCH_CHECK(
*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str());
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str());
}

int pack_factor = 32 / b_q_type->size_bits();
int pack_factor = 32 / b_q_type.size_bits();

int max_par = 4;

@ -575,7 +576,7 @@ torch::Tensor marlin_gemm_moe(
topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
*b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
num_experts, topk, moe_block_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par,
replicate_input, apply_weights);
@ -13,8 +13,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"int b_q_type, SymInt size_m, "
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
"topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor");
// conditionally compiled so impl registration is in source file
@ -80,7 +80,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& b_zeros,
|
||||
torch::Tensor& g_idx, torch::Tensor& perm,
|
||||
torch::Tensor& workspace,
|
||||
vllm::ScalarTypeTorchPtr const& b_q_type,
|
||||
vllm::ScalarTypeId const b_q_type_id,
|
||||
int64_t size_m, int64_t size_n, int64_t size_k,
|
||||
bool is_k_full, bool has_zp) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
@ -2132,22 +2132,23 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& b_zeros,
|
||||
torch::Tensor& g_idx, torch::Tensor& perm,
|
||||
torch::Tensor& workspace,
|
||||
vllm::ScalarTypeTorchPtr const& b_q_type,
|
||||
vllm::ScalarTypeId const& b_q_type_id,
|
||||
int64_t size_m, int64_t size_n, int64_t size_k,
|
||||
bool is_k_full, bool has_zp,
|
||||
bool use_fp32_reduce) {
|
||||
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8,
|
||||
"b_q_type must be u4 or u8 when has_zp = True. Got = ",
|
||||
b_q_type->str());
|
||||
TORCH_CHECK(
|
||||
b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
|
||||
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
|
||||
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
|
||||
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
|
||||
b_q_type->str());
|
||||
b_q_type.str());
|
||||
}
|
||||
|
||||
int pack_factor = 32 / b_q_type->size_bits();
|
||||
int pack_factor = 32 / b_q_type.size_bits();
|
||||
|
||||
// Verify A
|
||||
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
|
||||
@ -2279,7 +2280,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
|
||||
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
||||
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
|
||||
workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
|
||||
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
@ -2288,7 +2289,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
|
||||
workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
|
||||
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
|
||||
} else {
|
||||
@ -2302,4 +2303,4 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
|
||||
}
|
||||
}
|
||||
|
@ -38,9 +38,10 @@ static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
|
||||
// Interface
|
||||
//
|
||||
|
||||
std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
|
||||
std::vector<std::string> supported_schedules(ScalarTypeId const btype_id) {
|
||||
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
|
||||
return scalar_type_dispatch(*btype, [&](auto BType) {
|
||||
vllm::ScalarType b_type = ScalarType::from_id(btype_id);
|
||||
return scalar_type_dispatch(b_type, [&](auto BType) {
|
||||
return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
|
||||
});
|
||||
#else
|
||||
@ -49,7 +50,7 @@ std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
|
||||
}
|
||||
|
||||
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
|
||||
ScalarTypeTorchPtr const& btype,
|
||||
ScalarTypeId const btype_id,
|
||||
c10::optional<torch::Tensor> const& scales,
|
||||
c10::optional<torch::Tensor> const& zeros,
|
||||
c10::optional<int64_t> group_size,
|
||||
@ -57,6 +58,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
|
||||
c10::optional<double> alpha, c10::optional<double> beta,
|
||||
c10::optional<std::string> schedule) {
|
||||
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
|
||||
ScalarType const btype = ScalarType::from_id(btype_id);
|
||||
auto args = PyTorchArguments{.A = A,
|
||||
.B = B,
|
||||
.scales = scales,
|
||||
@ -67,7 +69,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
|
||||
.beta = beta,
|
||||
.schedule = schedule};
|
||||
|
||||
return scalar_type_dispatch(*btype, [&](auto BType) {
|
||||
return scalar_type_dispatch(btype, [&](auto BType) {
|
||||
return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
|
||||
A.scalar_type(), "machete_gemm", [&] {
|
||||
using ComputeType = equivalent_cutlass_type_t<scalar_t>;
|
||||
@ -79,9 +81,9 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
|
||||
#endif
|
||||
}
|
||||
|
||||
torch::Tensor prepack_B(torch::Tensor const& B,
|
||||
vllm::ScalarTypeTorchPtr const& btype) {
|
||||
return scalar_type_dispatch(*btype, [&](auto BType) {
|
||||
torch::Tensor prepack_B(torch::Tensor const& B, ScalarTypeId const btype_id) {
|
||||
ScalarType const btype = ScalarType::from_id(btype_id);
|
||||
return scalar_type_dispatch(btype, [&](auto BType) {
|
||||
return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
|
||||
});
|
||||
}
|
||||
|
@ -89,7 +89,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_meta,
|
||||
torch::Tensor& b_scales,
|
||||
torch::Tensor& workspace,
|
||||
vllm::ScalarTypeTorchPtr const& b_q_type,
|
||||
vllm::ScalarTypeId const b_q_type_id,
|
||||
int64_t size_m, int64_t size_n,
|
||||
int64_t size_k) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
@ -1029,13 +1029,14 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_meta,
|
||||
torch::Tensor& b_scales,
|
||||
torch::Tensor& workspace,
|
||||
vllm::ScalarTypeTorchPtr const& b_q_type,
|
||||
vllm::ScalarTypeId const b_q_type_id,
|
||||
int64_t size_m, int64_t size_n,
|
||||
int64_t size_k) {
|
||||
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
|
||||
// Verify num_bits
|
||||
TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
|
||||
"num_bits must be uint4b8 or uint8b128. Got = ", b_q_type->str());
|
||||
int pack_factor = 32 / b_q_type->size_bits();
|
||||
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
|
||||
"num_bits must be uint4b8 or uint8b128. Got = ", b_q_type.str());
|
||||
int pack_factor = 32 / b_q_type.size_bits();
|
||||
|
||||
// Verify M
|
||||
TORCH_CHECK(size_m == a.size(0),
|
||||
@ -1130,8 +1131,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
marlin_24::marlin_cuda_2_4(
|
||||
a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(),
|
||||
b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(),
|
||||
b_q_type->size_bits(), groupsize, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par);
|
||||
b_q_type.size_bits(), groupsize, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_m, sms, max_par);
|
||||
|
||||
return c;
|
||||
}
|
||||
|
@ -140,13 +140,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// Quantized GEMM for AWQ.
|
||||
ops.def(
|
||||
"awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
|
||||
"Tensor _zeros, int split_k_iters) -> Tensor");
|
||||
"Tensor _zeros, SymInt split_k_iters) -> Tensor");
|
||||
ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
|
||||
|
||||
// Dequantization for AWQ.
|
||||
ops.def(
|
||||
"awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
|
||||
"Tensor _zeros, int split_k_iters, int thx, int thy) -> Tensor");
|
||||
"Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
|
||||
ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
|
||||
|
||||
// Note about marlin kernel 'workspace' arguments:
|
||||
@ -166,32 +166,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// Marlin (Dense) Optimized Quantized GEMM for GPTQ.
|
||||
ops.def(
|
||||
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
||||
"Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
|
||||
"Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
|
||||
"Tensor");
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
|
||||
ops.def(
|
||||
"gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
|
||||
"Tensor b_scales, Tensor workspace, "
|
||||
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
|
||||
"int size_m, int size_n, int size_k) -> Tensor");
|
||||
"int b_q_type, "
|
||||
"SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
|
||||
ops.def("machete_supported_schedules(int btype) -> str[]");
|
||||
ops.def(
|
||||
"machete_supported_schedules("
|
||||
" __torch__.torch.classes._core_C.ScalarType btype"
|
||||
") -> str[]");
|
||||
ops.def(
|
||||
"machete_gemm(Tensor A, Tensor B,"
|
||||
" __torch__.torch.classes._core_C.ScalarType btype,"
|
||||
" Tensor? scales, Tensor? zeros, int? group_size,"
|
||||
"machete_gemm(Tensor A, Tensor B, int btype, "
|
||||
" Tensor? scales, Tensor? zeros, int? group_size, "
|
||||
" Tensor? C, float? alpha, float? beta, str? schedule)"
|
||||
"-> Tensor");
|
||||
ops.def(
|
||||
"machete_prepack_B(Tensor B,"
|
||||
" __torch__.torch.classes._core_C.ScalarType btype)"
|
||||
"-> Tensor");
|
||||
ops.def("machete_prepack_B(Tensor B, int btype) -> Tensor");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
||||
@ -201,8 +195,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def(
|
||||
"gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
||||
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
|
||||
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
|
||||
"int size_m, int size_n, int size_k, bool is_k_full, "
|
||||
"int b_q_type, "
|
||||
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
|
||||
"bool has_zp, bool use_fp32_reduce) -> Tensor");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
@ -219,32 +213,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// conditionally compiled so impl registrations are in source file
|
||||
|
||||
// Dequantization for GGML.
|
||||
ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
|
||||
ops.def("ggml_dequantize(Tensor W, int type, SymInt m, SymInt n) -> Tensor");
|
||||
ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
|
||||
|
||||
// mmvq kernel for GGML.
|
||||
ops.def(
|
||||
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, int row) "
|
||||
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) "
|
||||
"-> Tensor");
|
||||
ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
|
||||
|
||||
// mmq kernel for GGML.
|
||||
ops.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, int row) -> Tensor");
|
||||
ops.def(
|
||||
"ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
|
||||
ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
|
||||
|
||||
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
|
||||
ops.def(
|
||||
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
||||
"Tensor! workspace, int num_bits, int size_m, int size_n, "
|
||||
"int size_k) -> Tensor");
|
||||
"Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
|
||||
"SymInt size_k) -> Tensor");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// marlin_qqq_gemm for QQQ.
|
||||
ops.def(
|
||||
"marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
|
||||
"Tensor s_tok, Tensor s_ch, Tensor s_group, "
|
||||
"Tensor! workspace, int size_m, int size_n, "
|
||||
"int size_k) -> Tensor");
|
||||
"Tensor! workspace, SymInt size_m, SymInt size_n, "
|
||||
"SymInt size_k) -> Tensor");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||
|
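The schema rewrites in this bindings file are the core of the fix: dimension-like arguments move from int to SymInt so torch.compile can trace them symbolically, and the ScalarType custom-class argument becomes a plain int id. Below is a small self-contained sketch of an op declared the same way, under a hypothetical demo namespace and using the PyTorch 2.4-style register_fake API (earlier versions used impl_abstract); it shows only the declaration pattern, not the real kernels:

    import torch

    lib = torch.library.Library("demo", "DEF")  # hypothetical namespace
    lib.define(
        "toy_gemm(Tensor a, int b_type_id, SymInt size_m, SymInt size_n) -> Tensor")

    def toy_gemm(a, b_type_id, size_m, size_n):
        # stand-in "kernel": only the output shape matters for this sketch
        return a.new_zeros((size_m, size_n))

    lib.impl("toy_gemm", toy_gemm, "CompositeExplicitAutograd")

    @torch.library.register_fake("demo::toy_gemm")
    def _toy_gemm_fake(a, b_type_id, size_m, size_n):
        # shape-only fake used while tracing with symbolic sizes
        return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)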
@ -39,7 +39,6 @@ assert cwd != package_path, "should not import from the current directory"

files_to_copy = [
"vllm/_C.abi3.so",
"vllm/_core_C.abi3.so",
"vllm/_moe_C.abi3.so",
"vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so",
"vllm/vllm_flash_attn/flash_attn_interface.py",

setup.py
@ -290,10 +290,6 @@ def _build_custom_ops() -> bool:
return _is_cuda() or _is_hip() or _is_cpu()


def _build_core_ext() -> bool:
return not (_is_neuron() or _is_tpu() or _is_openvino() or _is_xpu())


def get_hipcc_rocm_version():
# Run the hipcc --version command
result = subprocess.run(['hipcc', '--version'],
@ -456,9 +452,6 @@ def get_requirements() -> List[str]:

ext_modules = []

if _build_core_ext():
ext_modules.append(CMakeExtension(name="vllm._core_C"))

if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._moe_C"))
@ -69,11 +69,11 @@ def check_full_graph_support(model,
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"

# Inductor doesn't support fp8/gptq_marlin_24 yet.
# Inductor doesn't support fp8 and the base meta llama uses too
# much memory.
quantization = model_kwargs.get("quantization")
if (quantization == "fp8" or quantization == "gptq_marlin"
or quantization == "gptq_marlin_24"
) and optimization_level >= CompilationLevel.INDUCTOR:
if ((quantization == "fp8" or model == "meta-llama/Meta-Llama-3-8B")
and optimization_level >= CompilationLevel.INDUCTOR):
return

prompts = [
@ -80,7 +80,7 @@ def machete_quantize_and_pack(w: torch.Tensor,
|
||||
w_q = w_q.t().contiguous().t() # convert to col major
|
||||
w_q_machete = ops.machete_prepack_B(w_q, wtype)
|
||||
|
||||
opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype))
|
||||
opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype.id))
|
||||
|
||||
return w_ref, w_q_machete, w_s, w_zp
|
||||
|
||||
@ -153,9 +153,10 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
|
||||
schedule=schedule,
|
||||
)
|
||||
|
||||
opcheck(torch.ops._C.machete_gemm,
|
||||
(a, w_q_machete, wtype, w_s, maybe_convert_zeropoints(
|
||||
w_zp, w_s), group_size, None, None, None, schedule))
|
||||
opcheck(
|
||||
torch.ops._C.machete_gemm,
|
||||
(a, w_q_machete, wtype.id, w_s, maybe_convert_zeropoints(
|
||||
w_zp, w_s), group_size, None, None, None, schedule))
|
||||
|
||||
# Relax atol as our reduction dim becomes larger (more rounding error)
|
||||
# Relax atol when we have zeropoints since the way machete applies
|
||||
|
@ -225,7 +225,7 @@ def test_gptq_marlin_gemm(
|
||||
opcheck(
|
||||
torch.ops._C.gptq_marlin_gemm,
|
||||
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
|
||||
workspace.scratch, quant_type, a_input.shape[0], b_weight.shape[1],
|
||||
workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1],
|
||||
a_input.shape[1], is_k_full, False, use_fp32_reduce),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
|
||||
|
||||
@ -254,6 +254,16 @@ def test_gptq_marlin_gemm(
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
# TODO: find better way to test this?
|
||||
@torch.compile(fullgraph=True)
|
||||
def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta,
|
||||
marlin_24_s, scratch, quant_type, size_m, size_n,
|
||||
size_k):
|
||||
return ops.gptq_marlin_24_gemm(a_input, marlin_24_q_w_comp, marlin_24_meta,
|
||||
marlin_24_s, scratch, quant_type, size_m,
|
||||
size_n, size_k)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
|
||||
@ -282,11 +292,11 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
|
||||
|
||||
opcheck(torch.ops._C.gptq_marlin_24_gemm,
|
||||
(a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
|
||||
workspace_24.scratch, quant_type, a_input.shape[0],
|
||||
workspace_24.scratch, quant_type.id, a_input.shape[0],
|
||||
b_weight.shape[1], a_input.shape[1]),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
|
||||
|
||||
output = ops.gptq_marlin_24_gemm(
|
||||
output = marlin_24_gemm_tester(
|
||||
a_input,
|
||||
marlin_24_q_w_comp,
|
||||
marlin_24_meta,
|
||||
|
@ -240,8 +240,8 @@ def test_fused_marlin_moe(
|
||||
requires_grad=False)
|
||||
opcheck(torch.ops._moe_C.marlin_gemm_moe,
|
||||
(a, qweight1, sorted_token_ids, topk_weights, topk_ids,
|
||||
scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m,
|
||||
2 * n, k, True, e, topk, block_size_m, True, False))
|
||||
scales1, zp, g_idx1, sort_indices1, workspace, quant_type.id,
|
||||
m, 2 * n, k, True, e, topk, block_size_m, True, False))
|
||||
|
||||
|
||||
@pytest.mark.skip("This test is here for the sake of debugging, "
|
||||
|
@ -32,5 +32,5 @@ def test_scalar_type_min_max(type_tuple):
max = torch.iinfo(torch_type).max

print(t, min, max, t.min(), t.max())
assert min == t.min()
assert max == t.max()
assert min == t.min(), f"min: {min} != {t.min()}"
assert max == t.max(), f"max: {max} != {t.max()}"
@ -16,7 +16,6 @@ Typical output looks like this:
2.6 weighted s to build ...torch_bindings.cpp.o (31.5 s elapsed time)
3.2 weighted s to build ...torch_bindings.cpp.o (38.5 s elapsed time)
Longest build steps for .so (linking):
0.1 weighted s to build _core_C.abi3.so (0.7 s elapsed time)
0.1 weighted s to build _moe_C.abi3.so (1.0 s elapsed time)
0.5 weighted s to build ...flash_attn_c.abi3.so (1.1 s elapsed time)
6.2 weighted s to build _C.abi3.so (6.2 s elapsed time)
@ -1,278 +0,0 @@
|
||||
import importlib.util
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
core_C_available = importlib.util.find_spec('._core_C', 'vllm') is not None
|
||||
|
||||
|
||||
# Mirrors enum in `core/scalar_type.hpp`
|
||||
class NanRepr(Enum):
|
||||
NONE = 0 # nans are not supported
|
||||
IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
|
||||
EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
|
||||
|
||||
|
||||
if TYPE_CHECKING or not core_C_available:
|
||||
# On platforms were we cannot use/build the C++ core extension (i.e. namely
|
||||
# neuron and tpu), we define the mock ScalarType class here that partially
|
||||
# mimics the C++ ScalarType class.
|
||||
#
|
||||
# We also use this provide type signatures to the Python LSP for the methods
|
||||
# in the C++ ScalarType class. So these type signatures should be kept
|
||||
# in sync with csrc/core/scalar_type.hpp
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScalarType:
|
||||
"""
|
||||
ScalarType can represent a wide range of floating point and integer
|
||||
types, in particular it can be used to represent sub-byte data types
|
||||
(something that torch.dtype currently does not support). It is also
|
||||
capable of representing types with a bias, i.e.:
|
||||
`stored_value = value + bias`,
|
||||
this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
|
||||
of 8). The implementation for this class can be found in
|
||||
csrc/core/scalar_type.hpp, these type signatures should be kept in sync
|
||||
with that file.
|
||||
"""
|
||||
|
||||
exponent: int
|
||||
"""
|
||||
Number of bits in the exponent if this is a floating point type
|
||||
(zero if this an integer type)
|
||||
"""
|
||||
|
||||
mantissa: int
|
||||
"""
|
||||
Number of bits in the mantissa if this is a floating point type,
|
||||
or the number bits representing an integer excluding the sign bit if
|
||||
this an integer type.
|
||||
"""
|
||||
|
||||
bias: int
|
||||
"""
|
||||
bias used to encode the values in this scalar type
|
||||
(value = stored_value - bias, default 0) for example if we store the
|
||||
type as an unsigned integer with a bias of 128 then the value 0 will be
|
||||
stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
|
||||
"""
|
||||
|
||||
signed: bool
|
||||
"If the type is signed (i.e. has a sign bit)"
|
||||
|
||||
_finite_values_only: bool = False
|
||||
"""
|
||||
Private: if NANs are supported, used `has_infs()` instead.
|
||||
"""
|
||||
|
||||
nan_repr: int = NanRepr.IEEE_754.value
|
||||
"""
|
||||
How NaNs are represent in this scalar type, returns NanRepr value.
|
||||
(not applicable for integer types)
|
||||
"""
|
||||
|
||||
@property
|
||||
def size_bits(self):
|
||||
return self.exponent + self.mantissa + int(self.signed)
|
||||
|
||||
def min(self) -> Union[int, float]:
|
||||
"""
|
||||
Min representable value for this scalar type.
|
||||
(accounting for bias if there is one)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def max(self) -> Union[int, float]:
|
||||
"""
|
||||
Max representable value for this scalar type.
|
||||
(accounting for bias if there is one)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def is_signed(self) -> bool:
|
||||
"""
|
||||
If the type is signed (i.e. has a sign bit), same as `signed`
|
||||
added for consistency with:
|
||||
https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
|
||||
"""
|
||||
...
|
||||
|
||||
def is_floating_point(self) -> bool:
|
||||
"If the type is a floating point type"
|
||||
return self.exponent != 0
|
||||
|
||||
def is_integer(self) -> bool:
|
||||
"If the type is an integer type"
|
||||
return self.exponent == 0
|
||||
|
||||
def has_bias(self) -> bool:
|
||||
"If the type has a non-zero bias"
|
||||
return self.bias != 0
|
||||
|
||||
def has_infs(self) -> bool:
|
||||
"If the type is floating point and supports infinity"
|
||||
return not self._finite_values_only
|
||||
|
||||
def has_nans(self) -> bool:
|
||||
return self.nan_repr != NanRepr.NONE.value
|
||||
|
||||
def is_ieee_754(self) -> bool:
|
||||
"""
|
||||
If the type is a floating point type that follows IEEE 754
|
||||
conventions
|
||||
"""
|
||||
return self.nan_repr == NanRepr.IEEE_754.value and \
|
||||
not self._finite_values_only
|
||||
|
||||
def __str__(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
# __len__ needs to be defined (and has to throw TypeError) for pytorch's
|
||||
# opcheck to work.
|
||||
def __len__(self) -> int:
|
||||
raise TypeError
|
||||
|
||||
#
|
||||
# Convenience Constructors
|
||||
#
|
||||
|
||||
@classmethod
|
||||
def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
|
||||
"Create a signed integer scalar type (size_bits includes sign-bit)."
|
||||
return cls(size_bits - 1, size_bits, bias if bias else 0, True)
|
||||
|
||||
@classmethod
|
||||
def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
|
||||
"""Create a unsigned integer scalar type."""
|
||||
return cls(size_bits, size_bits, bias if bias else 0, False)
|
||||
|
||||
@classmethod
|
||||
def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
|
||||
"""
|
||||
Create a standard floating point type
|
||||
(i.e. follows IEEE 754 conventions).
|
||||
"""
|
||||
return cls(exponent, mantissa, 0, True)
|
||||
|
||||
@classmethod
|
||||
def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
|
||||
nan_repr: int) -> 'ScalarType':
|
||||
"""
|
||||
Create a non-standard floating point type
|
||||
(i.e. does not follow IEEE 754 conventions).
|
||||
"""
|
||||
return cls(exponent, mantissa, 0, True, finite_values_only,
|
||||
nan_repr)
|
||||
|
||||
elif core_C_available:
|
||||
try:
|
||||
import vllm._core_C # noqa: F401
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from vllm._core_C with %r", e)
|
||||
|
||||
ScalarType = torch.classes._core_C.ScalarType
|
||||
|
||||
if (hasattr(torch, "_library")
|
||||
and hasattr(torch._library, "register_fake_class")):
|
||||
# Needed for dynamo support of ScalarType.
|
||||
@torch._library.register_fake_class("_core_C::ScalarType")
|
||||
class FakeScalarType:
|
||||
|
||||
def __init__(self, scalar_type):
|
||||
self.ScalarType = scalar_type
|
||||
|
||||
def bias_getter(self) -> int:
|
||||
return self.ScalarType.bias
|
||||
|
||||
def exponent_getter(self) -> int:
|
||||
return self.ScalarType.exponent
|
||||
|
||||
def mantissa_getter(self) -> int:
|
||||
return self.ScalarType.mantissa
|
||||
|
||||
def signed_getter(self) -> bool:
|
||||
return self.ScalarType.signed
|
||||
|
||||
def size_bits_getter(self) -> int:
|
||||
return self.ScalarType.size_bits
|
||||
|
||||
@property
|
||||
def size_bits(self) -> int:
|
||||
return self.ScalarType.size_bits
|
||||
|
||||
def min(self) -> Union[int, float]:
|
||||
return self.ScalarType.min()
|
||||
|
||||
def max(self) -> Union[int, float]:
|
||||
return self.ScalarType.max()
|
||||
|
||||
def is_signed(self) -> bool:
|
||||
return self.ScalarType.is_signed()
|
||||
|
||||
def is_floating_point(self) -> bool:
|
||||
return self.ScalarType.is_floating_point()
|
||||
|
||||
def is_integer(self) -> bool:
|
||||
return self.ScalarType.is_integer()
|
||||
|
||||
def has_bias(self) -> bool:
|
||||
return self.ScalarType.has_bias()
|
||||
|
||||
def has_infs(self) -> bool:
|
||||
return self.ScalarType.has_infs()
|
||||
|
||||
def has_nans(self) -> bool:
|
||||
return self.ScalarType.has_nans()
|
||||
|
||||
def is_ieee_754(self) -> bool:
|
||||
return self.ScalarType.is_ieee_754()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.ScalarType.__str__()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.ScalarType.__repr__()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.ScalarType.__len__()
|
||||
|
||||
def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
|
||||
return torch.classes._core_C.ScalarType.__obj_flatten__(
|
||||
self.ScalarType)
|
||||
|
||||
@classmethod
|
||||
def __obj_unflatten__(
|
||||
cls, flat_type: Tuple[Tuple[str, Any],
|
||||
...]) -> 'ScalarType':
|
||||
return cls(
|
||||
torch.classes._core_C.ScalarType.__obj_unflatten__(
|
||||
flat_type))
|
||||
|
||||
@classmethod
|
||||
def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
|
||||
return ScalarType.int_(size_bits, bias)
|
||||
|
||||
@classmethod
|
||||
def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
|
||||
return ScalarType.uint(size_bits, bias)
|
||||
|
||||
@classmethod
|
||||
def float_IEEE754(cls, exponent: int,
|
||||
mantissa: int) -> 'ScalarType':
|
||||
return ScalarType.float_IEEE754(exponent, mantissa)
|
||||
|
||||
@classmethod
|
||||
def float_(cls, exponent: int, mantissa: int,
|
||||
finite_values_only: bool,
|
||||
nan_repr: int) -> 'ScalarType':
|
||||
return ScalarType.float_(exponent, mantissa,
|
||||
finite_values_only, nan_repr)
|
@ -6,9 +6,9 @@ import torch
|
||||
import torch.library
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._core_ext import ScalarType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -306,7 +306,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
workspace: torch.Tensor, b_q_type: ScalarType,
|
||||
size_m: int, size_n: int, size_k: int) -> torch.Tensor:
|
||||
return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
|
||||
workspace, b_q_type, size_m,
|
||||
workspace, b_q_type.id, size_m,
|
||||
size_n, size_k)
|
||||
|
||||
|
||||
@ -316,8 +316,9 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
b_meta: torch.Tensor, b_scales: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
b_q_type: ScalarType, size_m: int,
|
||||
size_n: int, size_k: int) -> torch.Tensor:
|
||||
b_q_type: ScalarType, size_m: torch.SymInt,
|
||||
size_n: torch.SymInt,
|
||||
size_k: torch.SymInt) -> torch.Tensor:
|
||||
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
|
||||
|
||||
@register_fake("_C::gptq_marlin_gemm")
|
||||
@ -329,17 +330,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
perm: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
b_q_type: ScalarType,
|
||||
size_m: int,
|
||||
size_n: int,
|
||||
size_k: int,
|
||||
size_m: torch.SymInt,
|
||||
size_n: torch.SymInt,
|
||||
size_k: torch.SymInt,
|
||||
is_k_full: bool,
|
||||
has_zp: bool = False,
|
||||
use_fp32_reduce: bool = False) -> torch.Tensor:
|
||||
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
|
||||
|
||||
@register_fake("_C::ggml_dequantize")
|
||||
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
|
||||
n: int) -> torch.Tensor:
|
||||
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int,
|
||||
m: torch.SymInt,
|
||||
n: torch.SymInt) -> torch.Tensor:
|
||||
return torch.empty((m, n), dtype=torch.float16, device=W.device)
|
||||
|
||||
@register_fake("_C::ggml_mul_mat_vec_a8")
|
||||
@ -347,7 +349,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
W: torch.Tensor,
|
||||
X: torch.Tensor,
|
||||
quant_type: int,
|
||||
row: int,
|
||||
row: torch.SymInt,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty((1, row), dtype=torch.float16, device=W.device)
|
||||
|
||||
@ -356,7 +358,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
W: torch.Tensor,
|
||||
X: torch.Tensor,
|
||||
quant_type: int,
|
||||
row: int,
|
||||
row: torch.SymInt,
|
||||
) -> torch.Tensor:
|
||||
batch = X.size(0)
|
||||
return torch.empty((batch, row), dtype=torch.float16, device=W.device)
|
||||
@ -365,8 +367,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
s_tok: torch.Tensor, s_ch: torch.Tensor,
|
||||
s_group: torch.Tensor, workspace: torch.Tensor,
|
||||
size_m: int, size_n: int,
|
||||
size_k: int) -> torch.Tensor:
|
||||
size_m: torch.SymInt, size_n: torch.SymInt,
|
||||
size_k: torch.SymInt) -> torch.Tensor:
|
||||
return torch.empty((size_m, size_n),
|
||||
dtype=torch.float16,
|
||||
device=a.device)
|
||||
@ -374,16 +376,16 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
@register_fake("_C::marlin_gemm")
|
||||
def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
b_scales: torch.Tensor, workspace: torch.Tensor,
|
||||
size_m: int, size_n: int,
|
||||
size_k: int) -> torch.Tensor:
|
||||
size_m: torch.SymInt, size_n: torch.SymInt,
|
||||
size_k: torch.SymInt) -> torch.Tensor:
|
||||
return torch.empty((size_m, size_n),
|
||||
dtype=torch.float16,
|
||||
device=a.device)
|
||||
|
||||
@register_fake("_C::awq_dequantize")
|
||||
def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
|
||||
zeros: torch.Tensor, split_k_iters: int, thx: int,
|
||||
thy: int) -> torch.Tensor:
|
||||
zeros: torch.Tensor, split_k_iters: torch.SymInt,
|
||||
thx: int, thy: int) -> torch.Tensor:
|
||||
in_c = qweight.size(0)
|
||||
qout_c = qweight.size(1)
|
||||
out_c = qout_c * 8
|
||||
@ -394,7 +396,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
@register_fake("_C::awq_gemm")
|
||||
def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
|
||||
qzeros: torch.Tensor, scales: torch.Tensor,
|
||||
split_k_iters: int) -> torch.Tensor:
|
||||
split_k_iters: torch.SymInt) -> torch.Tensor:
|
||||
num_in_feats = input.size(0)
|
||||
return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8),
|
||||
dtype=input.dtype,
|
||||
@ -429,8 +431,9 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
@register_fake("_C::fp8_marlin_gemm")
|
||||
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
b_scales: torch.Tensor, workspace: torch.Tensor,
|
||||
num_bits: int, size_m: int, size_n: int,
|
||||
size_k: int) -> torch.Tensor:
|
||||
num_bits: int, size_m: torch.SymInt,
|
||||
size_n: torch.SymInt,
|
||||
size_k: torch.SymInt) -> torch.Tensor:
|
||||
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
|
||||
|
||||
@register_fake("_C::machete_gemm")
|
||||
@ -457,40 +460,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
return torch.empty_like(b_q_weight,
|
||||
memory_format=torch.contiguous_format)
|
||||
|
||||
@register_fake("_C::causal_conv1d_fwd")
|
||||
def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
|
||||
bias_: Optional[torch.Tensor],
|
||||
conv_states: Optional[torch.Tensor],
|
||||
cu_seq_len: Optional[torch.Tensor],
|
||||
cache_indices: Optional[torch.Tensor],
|
||||
has_initial_state: Optional[torch.Tensor],
|
||||
silu_activation: bool, pad_slot_id: int):
|
||||
return None
|
||||
|
||||
@register_fake("_C::causal_conv1d_update")
|
||||
def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias_: Optional[torch.Tensor],
|
||||
silu_activation: bool,
|
||||
cache_seqlens: Optional[torch.Tensor],
|
||||
conv_state_indices: Optional[torch.Tensor],
|
||||
pad_slot_id: int) -> None:
|
||||
return None
|
||||
|
||||
@register_fake("_C::selective_scan_fwd")
|
||||
def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
|
||||
A: torch.Tensor, B: torch.Tensor,
|
||||
C: torch.Tensor, D_: Optional[torch.Tensor],
|
||||
z_: Optional[torch.Tensor],
|
||||
delta_bias_: Optional[torch.Tensor],
|
||||
delta_softplus: bool,
|
||||
cu_seq_len: Optional[torch.Tensor],
|
||||
cache_indices: Optional[torch.Tensor],
|
||||
has_initial_state: Optional[torch.Tensor],
|
||||
ssm_states: Optional[torch.Tensor],
|
||||
pad_slot_id: int) -> None:
|
||||
return None
|
||||
|
||||
|
||||
# cutlass
|
||||
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
|
||||
@ -611,7 +580,7 @@ def gptq_marlin_gemm(a: torch.Tensor,
|
||||
has_zp: bool = False,
|
||||
use_fp32_reduce: bool = False) -> torch.Tensor:
|
||||
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
|
||||
g_idx, perm, workspace, b_q_type,
|
||||
g_idx, perm, workspace, b_q_type.id,
|
||||
size_m, size_n, size_k, is_k_full,
|
||||
has_zp, use_fp32_reduce)
|
||||
|
||||
@ -627,7 +596,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
|
||||
# machete
|
||||
def machete_supported_schedules(b_type: ScalarType) -> List[str]:
|
||||
return torch.ops._C.machete_supported_schedules(b_type)
|
||||
return torch.ops._C.machete_supported_schedules(b_type.id)
|
||||
|
||||
|
||||
def machete_gemm(
|
||||
@ -642,13 +611,13 @@ def machete_gemm(
|
||||
beta: Optional[float] = None,
|
||||
schedule: Optional[str] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops._C.machete_gemm(a, b_q, b_type, b_scales, b_zeros,
|
||||
return torch.ops._C.machete_gemm(a, b_q, b_type.id, b_scales, b_zeros,
|
||||
b_group_size, c, alpha, beta, schedule)
|
||||
|
||||
|
||||
def machete_prepack_B(b_q_weight: torch.Tensor,
|
||||
b_type: ScalarType) -> torch.Tensor:
|
||||
return torch.ops._C.machete_prepack_B(b_q_weight, b_type)
|
||||
return torch.ops._C.machete_prepack_B(b_q_weight, b_type.id)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "permute_cols"):
|
||||
@ -862,10 +831,10 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
|
||||
topk_ids: torch.Tensor, b_scales: torch.Tensor,
|
||||
b_zero_points: torch.Tensor, g_idx: torch.Tensor,
|
||||
perm: torch.Tensor, workspace: torch.Tensor,
|
||||
b_q_type: ScalarType, size_m: int, size_n: int,
|
||||
size_k: int, is_k_full: bool, num_experts: int,
|
||||
topk: int, moe_block_size: int,
|
||||
replicate_input: bool,
|
||||
b_q_type: ScalarType, size_m: torch.SymInt,
|
||||
size_n: torch.SymInt, size_k: torch.SymInt,
|
||||
is_k_full: bool, num_experts: int, topk: int,
|
||||
moe_block_size: int, replicate_input: bool,
|
||||
apply_weights: bool) -> torch.Tensor:
|
||||
return torch.empty((size_m, topk, size_n),
|
||||
dtype=a.dtype,
|
||||
|
@ -116,7 +116,7 @@ def single_marlin_moe(

intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K,
w_zeros, g_idx, sort_indices, workspace, scalar_type.id, M, N, K,
is_k_full, E, topk, block_size_m, True, False)

return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
@ -272,7 +272,7 @@ def fused_marlin_moe(
g_idx1,
sort_indices1,
workspace,
scalar_type1,
scalar_type1.id,
M,
2 * N,
K,
@ -297,7 +297,7 @@ def fused_marlin_moe(
g_idx2,
sort_indices2,
workspace,
scalar_type2,
scalar_type2.id,
M,
K,
N,
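With the wrappers above passing scalar_type.id and the fake ops declaring torch.SymInt sizes, these kernels can be captured by torch.compile with fullgraph=True, which is what the marlin_24_gemm_tester added to the tests exercises. A hedged sketch of that usage pattern (import path and argument shapes assumed, mirroring the test):

    import torch
    from vllm import _custom_ops as ops  # assumed import path

    @torch.compile(fullgraph=True)
    def marlin_24_gemm_compiled(a, w_comp, meta, scales, workspace, quant_type,
                                size_m, size_n, size_k):
        return ops.gptq_marlin_24_gemm(a, w_comp, meta, scales, workspace,
                                       quant_type, size_m, size_n, size_k)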
@ -1,4 +1,298 @@
|
||||
from ._core_ext import NanRepr, ScalarType
|
||||
import functools
|
||||
import struct
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
# Mirrors enum in `core/scalar_type.hpp`
|
||||
class NanRepr(Enum):
|
||||
NONE = 0 # nans are not supported
|
||||
IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
|
||||
EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
|
||||
|
||||
|
||||
# This ScalarType class is a parallel implementation of the C++ ScalarType
|
||||
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
|
||||
# in sync until the inductor fully supports custom C++ classes.
|
||||
@dataclass(frozen=True)
|
||||
class ScalarType:
|
||||
"""
|
||||
ScalarType can represent a wide range of floating point and integer
|
||||
types, in particular it can be used to represent sub-byte data types
|
||||
(something that torch.dtype currently does not support). It is also
|
||||
capable of representing types with a bias, i.e.:
|
||||
`stored_value = value + bias`,
|
||||
this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
|
||||
of 8). The implementation for this class can be found in
|
||||
csrc/core/scalar_type.hpp, these type signatures should be kept in sync
|
||||
with that file.
|
||||
"""
|
||||
|
||||
exponent: int
|
||||
"""
|
||||
Number of bits in the exponent if this is a floating point type
|
||||
(zero if this an integer type)
|
||||
"""
|
||||
|
||||
mantissa: int
|
||||
"""
|
||||
Number of bits in the mantissa if this is a floating point type,
|
||||
or the number bits representing an integer excluding the sign bit if
|
||||
this an integer type.
|
||||
"""
|
||||
|
||||
signed: bool
|
||||
"If the type is signed (i.e. has a sign bit)"
|
||||
|
||||
bias: int
|
||||
"""
|
||||
bias used to encode the values in this scalar type
|
||||
(value = stored_value - bias, default 0) for example if we store the
|
||||
type as an unsigned integer with a bias of 128 then the value 0 will be
|
||||
stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
|
||||
"""
|
||||
|
||||
_finite_values_only: bool = False
|
||||
"""
|
||||
Private: if infs are supported, used `has_infs()` instead.
|
||||
"""
|
||||
|
||||
nan_repr: NanRepr = NanRepr.IEEE_754
|
||||
"""
|
||||
How NaNs are represented in this scalar type, returns NanRepr value.
|
||||
(not applicable for integer types)
|
||||
"""
|
||||
|
||||
def _floating_point_max_int(self) -> int:
|
||||
assert (
|
||||
self.mantissa <= 52 and self.exponent <= 11
|
||||
), f"Cannot represent max/min as a double for type {self.__str__()}"
|
||||
|
||||
max_mantissa = (1 << self.mantissa) - 1
|
||||
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
|
||||
max_mantissa = max_mantissa - 1
|
||||
|
||||
max_exponent = (1 << self.exponent) - 2
|
||||
if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
|
||||
or self.nan_repr == NanRepr.NONE):
|
||||
assert (
|
||||
self.exponent < 11
|
||||
), f"Cannot represent max/min as a double for type {self.__str__()}"
|
||||
max_exponent = max_exponent + 1
|
||||
|
||||
# adjust the exponent to match that of a double
|
||||
# for now we assume the exponent bias is the standard 2^(e-1) -1, (where
|
||||
# e is the exponent bits), there is some precedent for non-standard
|
||||
# biases, example `float8_e4m3b11fnuz` here:
|
||||
# https://github.com/jax-ml/ml_dtypes but to avoid premature over
|
||||
# complication we are just assuming the standard exponent bias until
|
||||
# there is a need to support non-standard biases
|
||||
exponent_bias = (1 << (self.exponent - 1)) - 1
|
||||
exponent_bias_double = (1 << 10) - 1 # double e = 11
|
||||
|
||||
max_exponent_double = (max_exponent - exponent_bias +
|
||||
exponent_bias_double)
|
||||
|
||||
# shift the mantissa and exponent into the proper positions for an
|
||||
# IEEE double and bitwise-or them together.
|
||||
return (max_mantissa <<
|
||||
(52 - self.mantissa)) | (max_exponent_double << 52)
|
||||
|
||||
    def _floating_point_max(self) -> float:
        double_raw = self._floating_point_max_int()
        return struct.unpack('!d', struct.pack('!Q', double_raw))[0]

    def _raw_max(self) -> Union[int, float]:
        if self.is_floating_point():
            return self._floating_point_max()
        else:
            assert (self.size_bits < 64 or self.size_bits == 64
                    and self.is_signed()), "Cannot represent max as an int"
            return (1 << self.mantissa) - 1

    def _raw_min(self) -> Union[int, float]:
        if self.is_floating_point():
            assert self.is_signed(
            ), "We currently assume all floating point types are signed"
            sign_bit_double = 1 << 63

            max_raw = self._floating_point_max_int()
            min_raw = max_raw | sign_bit_double
            return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
        else:
            assert (not self.is_signed() or
                    self.size_bits <= 64), "Cannot represent min as a int64_t"

            if self.is_signed():
                return -(1 << (self.size_bits - 1))
            else:
                return 0

    @functools.cached_property
    def id(self) -> int:
        """
        Convert the ScalarType to an int which can be passed to pytorch custom
        ops. This layout of the int must be kept in sync with the C++
        ScalarType's from_id method.
        """
        val = 0
        offset = 0

        def or_and_advance(member, bit_width):
            nonlocal val
            nonlocal offset
            bit_mask = (1 << bit_width) - 1
            val = val | (int(member) & bit_mask) << offset
            offset = offset + bit_width

        or_and_advance(self.exponent, 8)
        or_and_advance(self.mantissa, 8)
        or_and_advance(self.signed, 1)
        or_and_advance(self.bias, 32)
        or_and_advance(self._finite_values_only, 1)
        or_and_advance(self.nan_repr.value, 8)

        assert offset <= 64, \
            f"ScalarType fields too big {offset} to fit into an int64"

        return val

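    # Resulting bit layout of `id` (least-significant bits first), which the
    # C++ `ScalarType::from_id` in csrc/core/scalar_type.hpp must mirror:
    #   bits  0-7   exponent
    #   bits  8-15  mantissa
    #   bit   16    signed
    #   bits 17-48  bias
    #   bit   49    _finite_values_only
    #   bits 50-57  nan_repr
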
    @property
    def size_bits(self) -> int:
        return self.exponent + self.mantissa + int(self.signed)

    def min(self) -> Union[int, float]:
        """
        Min representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_min() - self.bias

    def max(self) -> Union[int, float]:
        """
        Max representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_max() - self.bias

    def is_signed(self) -> bool:
        """
        If the type is signed (i.e. has a sign bit), same as `signed`;
        added for consistency with:
        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
        """
        return self.signed

    def is_floating_point(self) -> bool:
        "If the type is a floating point type"
        return self.exponent != 0

    def is_integer(self) -> bool:
        "If the type is an integer type"
        return self.exponent == 0

    def has_bias(self) -> bool:
        "If the type has a non-zero bias"
        return self.bias != 0

    def has_infs(self) -> bool:
        "If the type is floating point and supports infinity"
        return not self._finite_values_only

    def has_nans(self) -> bool:
        "If the type supports NaN values"
        return self.nan_repr != NanRepr.NONE

    def is_ieee_754(self) -> bool:
        """
        If the type is a floating point type that follows IEEE 754
        conventions
        """
        return self.nan_repr == NanRepr.IEEE_754 and \
            not self._finite_values_only

    def __str__(self) -> str:
        """
        naming generally follows: https://github.com/jax-ml/ml_dtypes
        for floating point types (leading f) the scheme is:
          `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
        flags:
          - no-flags: means it follows IEEE 754 conventions
          - f: means finite values only (no infinities)
          - n: means NaNs are supported (non-standard encoding)
        for integer types the scheme is:
          `[u]int<size_bits>[b<bias>]`
          - if bias is not present it means it is zero
        """
        if self.is_floating_point():
            ret = "float" + str(self.size_bits) + "_e" + str(
                self.exponent) + "m" + str(self.mantissa)

            if not self.is_ieee_754():
                if self._finite_values_only:
                    ret = ret + "f"
                if self.nan_repr != NanRepr.NONE:
                    ret = ret + "n"

            return ret
        else:
            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
            if self.has_bias():
                ret = ret + "b" + str(self.bias)
            return ret

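    # Examples of the naming scheme above:
    #   ScalarType.uint(4, 8)                                      -> "uint4b8"
    #   ScalarType.float_IEEE754(5, 2)                             -> "float8_e5m2"
    #   ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)  -> "float8_e4m3fn"
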
    def __repr__(self) -> str:
        return "ScalarType." + self.__str__()

    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
    # opcheck to work.
    def __len__(self) -> int:
        raise TypeError

    #
    # Convenience Constructors
    #

    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        "Create a signed integer scalar type (size_bits includes sign-bit)."
        ret = cls(0, size_bits - 1, True, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        """Create an unsigned integer scalar type."""
        ret = cls(0, size_bits, False, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
        """
        Create a standard floating point type
        (i.e. follows IEEE 754 conventions).
        """
        assert (mantissa > 0 and exponent > 0)
        ret = cls(exponent, mantissa, True, 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
               nan_repr: NanRepr) -> 'ScalarType':
        """
        Create a non-standard floating point type
        (i.e. does not follow IEEE 754 conventions).
        """
        assert (mantissa > 0 and exponent > 0)
        assert (nan_repr != NanRepr.IEEE_754), (
            "use `float_IEEE754` constructor for floating point types that "
            "follow IEEE 754 conventions")
        ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
        ret.id  # noqa B018: make sure the id is cached
        return ret

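
# Note on the constructors above: `size_bits` includes the sign bit for signed
# types, so `ScalarType.int_(8, None)` yields mantissa=7 plus a sign bit,
# while `ScalarType.uint(4, 8)` yields mantissa=4, no sign bit, and a bias of
# 8 (the unsigned 4-bit layout standard GPTQ uses).
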
# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
@ -17,14 +311,13 @@ class scalar_types:
    uint4 = ScalarType.uint(4, None)
    int8 = ScalarType.int_(8, None)
    uint8 = ScalarType.uint(8, None)
    float8_e4m3fn = ScalarType.float_(4, 3, True,
                                      NanRepr.EXTD_RANGE_MAX_MIN.value)
    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
    float16_e5m10 = ScalarType.float_IEEE754(5, 10)

    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value)
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)

    # "gptq" types
    uint2b2 = ScalarType.uint(2, 2)