[Bugfix] Fix support for dimension like integers and ScalarType (#9299)

This commit is contained in:
bnellnm 2024-10-17 15:08:34 -04:00 committed by GitHub
parent 0f41fbe5a3
commit eca2c5f7c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 427 additions and 677 deletions

View File

@ -230,14 +230,12 @@ steps:
commands:
- pytest -v -s compile/test_basic_correctness.py
# TODO: re-write in comparison tests, and fix symbolic shape
# for quantization ops.
# - label: "PyTorch Fullgraph Test" # 18min
# source_file_dependencies:
# - vllm/
# - tests/compile
# commands:
# - pytest -v -s compile/test_full_graph.py
- label: "PyTorch Fullgraph Test" # 18min
source_file_dependencies:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_full_graph.py
- label: Kernels Test %N # 1h each
mirror_hardwares: [amd]

View File

@ -83,24 +83,6 @@ endif()
#
find_package(Torch REQUIRED)
#
message(STATUS "Enabling core extension.")
# Define _core_C extension
# built for (almost) every target platform, (excludes TPU and Neuron)
set(VLLM_EXT_SRC
"csrc/core/torch_bindings.cpp")
define_gpu_extension_target(
_core_C
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
USE_SABI 3
WITH_SOABI)
#
# Forward the non-CUDA device extensions to external CMake scripts.
#

View File

@ -1,6 +1,7 @@
#pragma once
#include <torch/custom_class.h>
// For TORCH_CHECK
#include <torch/library.h>
namespace vllm {
@ -9,12 +10,7 @@ namespace vllm {
// in particular it can be used to represent sub-byte data types (something
// that torch.dtype currently does not support).
//
// ScalarTypeTorch is a subclass of ScalarType that is compatible with
// TORCH_LIBRARY, making it accessible from Python as well meaning this class
// can be used as a argument for custom operators, helping to simplify these
// interfaces.
//
// The type definitions on the Python side can be found in: vllm/_core_ext.pyi
// The type definitions on the Python side can be found in: vllm/scalar_type.py
// these type definitions should be kept up to date with any Python API changes
// here.
//
@ -308,204 +304,7 @@ class ScalarType {
}
};
// Create a TORCH_LIBRARY compatible version of ScalarType (i.e. inherit from
// torch::CustomClassHolder), we use multiple inheritance here since we cannot
// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
// constructor at the same time (torch::CustomClassHolder does not have a
// constexpr destructor)
// See also:
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
public:
ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
bool _signed)
: ScalarType(exponent, mantissa, bias, _signed){};
ScalarTypeTorch(ScalarType type) : ScalarType(type){};
using Base = ScalarType;
using Self = ScalarTypeTorch;
using SelfPtr = c10::intrusive_ptr<Self>;
static void check_size_bits(int64_t size_bits, bool signed_) {
TORCH_CHECK(
size_bits <=
std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
"size_bits bit width is too large to be represented");
}
static void check_bias(int64_t bias) {
using Bias = decltype(std::declval<Self>().bias);
TORCH_CHECK(bias <= std::numeric_limits<Bias>::max() &&
bias >= std::numeric_limits<Bias>::min(),
"bias too large or small to be represented");
}
static void check_exponent(int64_t exponent) {
TORCH_CHECK(
exponent <=
std::numeric_limits<decltype(std::declval<Self>().exponent)>::max(),
"exponent bit width is too large to be represented");
}
static void check_mantissa(int64_t mantissa) {
TORCH_CHECK(
mantissa <=
std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
"mantissa bit width is too large to be represented");
}
static SelfPtr int_(int64_t size_bits, c10::optional<int64_t> bias) {
check_size_bits(size_bits, true);
check_bias(bias.value_or(0));
return c10::make_intrusive<Self>(
ScalarType::int_(size_bits, bias.value_or(0)));
}
static SelfPtr uint(int64_t size_bits, c10::optional<int64_t> bias) {
check_size_bits(size_bits, true);
check_bias(bias.value_or(0));
return c10::make_intrusive<Self>(
ScalarType::uint(size_bits, bias.value_or(0)));
}
static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) {
check_mantissa(mantissa);
check_exponent(exponent);
return c10::make_intrusive<Self>(
ScalarType::float_IEEE754(exponent, mantissa));
}
static SelfPtr float_(int64_t exponent, int64_t mantissa,
bool finite_values_only, int64_t nan_repr) {
check_mantissa(mantissa);
check_exponent(exponent);
return c10::make_intrusive<Self>(ScalarType::float_(
exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
}
// This needs to be implemented and throw a TypeError in order for
// PyTorch's opcheck to work on ops that use ScalarTypes.
int64_t len() const {
throw c10::TypeError({__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
"__len__ not implemented");
return 0;
}
// Serialize a ScalarType into a tuple of pairs. Where each pair
// is a (fieldname, value).
// For simplicity, we are just going to convert to a ScalarTypeId.
std::tuple<std::tuple<std::string, int64_t>> obj_flatten() const {
return {{"ScalarType", id()}};
}
// Deserialize a scalar type that has been serialized by obj_flatten,
// ostensibly from a tuple of (member name, value) pairs, but in reality
// just a ScalarTypeId.
static SelfPtr obj_unflatten(
std::tuple<std::tuple<std::string, int64_t>> const& flat_type) {
return c10::make_intrusive<Self>(
from_id(std::get<1>(std::get<0>(flat_type))));
}
template <typename T>
static void bind_readonly_property(torch::class_<Self>& cls,
std::string const& name, T Base::*field) {
auto getter_func_helper = [field = std::move(field)](SelfPtr const& self) {
if constexpr (std::is_member_function_pointer_v<decltype(field)>) {
return (self.get()->*field)();
} else {
return self.get()->*field;
}
};
auto getter_func = [field = std::move(field),
getter_func_helper = std::move(getter_func_helper)](
SelfPtr const& self) {
auto val = getter_func_helper(self);
// upconvert uint8_t, int32_t etc. to int64_t for python
if constexpr (std::is_integral_v<T>) {
return static_cast<int64_t>(val);
} else {
return val;
}
};
cls.def_property(name, getter_func);
}
template <typename MemberFunc, typename Cls>
static void bind_function(torch::class_<Self>& cls, const std::string& name,
MemberFunc Cls::*member) {
cls.def(name, [member = std::move(member)](SelfPtr const& self) {
return (self.get()->*member)();
});
}
template <typename Func>
static void bind_function(torch::class_<Self>& cls, const std::string& name,
Func func) {
cls.def(name, func);
}
template <typename Func>
static void bind_static_function(torch::class_<Self>& cls,
const std::string& name, Func func) {
cls.def_static(name, func);
}
static void bind_class(torch::Library& lib) {
auto cls = lib.class_<ScalarTypeTorch>("ScalarType")
.def(torch::init<int64_t, int64_t, int64_t, bool>());
// Bind Properties
bind_readonly_property(cls, "mantissa", &Base::mantissa);
bind_readonly_property(cls, "exponent", &Base::exponent);
bind_readonly_property(cls, "bias", &Base::bias);
bind_readonly_property(cls, "signed", &Base::is_signed);
bind_readonly_property(cls, "size_bits", &Base::size_bits);
// Bind member functions
bind_function(cls, "is_signed", &Base::is_signed);
bind_function(cls, "is_integer", &Base::is_integer);
bind_function(cls, "is_floating_point", &Base::is_floating_point);
bind_function(cls, "is_ieee_754", &Base::is_ieee_754);
bind_function(cls, "has_nans", &Base::has_nans);
bind_function(cls, "has_infs", &Base::has_infs);
bind_function(cls, "has_bias", &Base::has_bias);
bind_function(cls, "max", [](SelfPtr const& self) {
return std::visit([](auto arg) { return c10::IValue(arg); },
self.get()->max());
});
bind_function(cls, "min", [](SelfPtr const& self) {
return std::visit([](auto arg) { return c10::IValue(arg); },
self.get()->min());
});
bind_function(cls, "__len__", &ScalarTypeTorch::len);
bind_function(cls, "__str__", &Base::str);
bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
return *self == *other;
});
bind_function(cls, "__repr__", [](SelfPtr const& self) {
return "ScalarType." + self.get()->str();
});
bind_function(cls, "__obj_flatten__", &ScalarTypeTorch::obj_flatten);
bind_static_function(cls, "__obj_unflatten__",
&ScalarTypeTorch::obj_unflatten);
// Bind static functions (convenience constructors)
bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
bind_static_function(cls, "float_IEEE754", &ScalarTypeTorch::float_IEEE754);
bind_static_function(cls, "float_", &ScalarTypeTorch::float_);
}
};
using ScalarTypeId = int64_t;
using ScalarTypeTorchPtr = c10::intrusive_ptr<ScalarTypeTorch>;
using ScalarTypeId = ScalarType::Id;
// "rust style" names generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70

View File

@ -1,16 +0,0 @@
#include <torch/library.h>
#include "scalar_type.hpp"
#include "registration.h"
// Note the CORE exstension will be built for (almost) all hardware targets so
// new additions must account for this. (currently not built for TPU and Neuron)
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, lib) {
// ScalarType, a custom class for representing data types that supports
// quantized types, declared here so it can be used when creating interfaces
// for custom ops.
vllm::ScalarTypeTorch::bind_class(lib);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

View File

@ -484,21 +484,22 @@ torch::Tensor marlin_gemm_moe(
const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
torch::Tensor& b_zeros, const torch::Tensor& g_idx,
const torch::Tensor& perm, torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n,
vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
int64_t moe_block_size, bool replicate_input, bool apply_weights) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
bool has_zp = b_zeros.size(1) != 0;
if (has_zp) {
TORCH_CHECK(
*b_q_type == vllm::kU4,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type->str());
b_q_type == vllm::kU4,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
} else {
TORCH_CHECK(
*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str());
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str());
}
int pack_factor = 32 / b_q_type->size_bits();
int pack_factor = 32 / b_q_type.size_bits();
int max_par = 4;
@ -575,7 +576,7 @@ torch::Tensor marlin_gemm_moe(
topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
*b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
num_experts, topk, moe_block_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par,
replicate_input, apply_weights);

View File

@ -13,8 +13,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"int b_q_type, SymInt size_m, "
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
"topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor");
// conditionally compiled so impl registration is in source file

View File

@ -80,7 +80,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
vllm::ScalarTypeId const b_q_type_id,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
@ -2132,22 +2132,23 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
vllm::ScalarTypeId const& b_q_type_id,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp,
bool use_fp32_reduce) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
if (has_zp) {
TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8,
"b_q_type must be u4 or u8 when has_zp = True. Got = ",
b_q_type->str());
TORCH_CHECK(
b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
} else {
TORCH_CHECK(
*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
b_q_type->str());
b_q_type.str());
}
int pack_factor = 32 / b_q_type->size_bits();
int pack_factor = 32 / b_q_type.size_bits();
// Verify A
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
@ -2279,7 +2280,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
@ -2288,7 +2289,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
} else {

View File

@ -38,9 +38,10 @@ static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
// Interface
//
std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
std::vector<std::string> supported_schedules(ScalarTypeId const btype_id) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
return scalar_type_dispatch(*btype, [&](auto BType) {
vllm::ScalarType b_type = ScalarType::from_id(btype_id);
return scalar_type_dispatch(b_type, [&](auto BType) {
return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
});
#else
@ -49,7 +50,7 @@ std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
}
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
ScalarTypeTorchPtr const& btype,
ScalarTypeId const btype_id,
c10::optional<torch::Tensor> const& scales,
c10::optional<torch::Tensor> const& zeros,
c10::optional<int64_t> group_size,
@ -57,6 +58,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
c10::optional<double> alpha, c10::optional<double> beta,
c10::optional<std::string> schedule) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
ScalarType const btype = ScalarType::from_id(btype_id);
auto args = PyTorchArguments{.A = A,
.B = B,
.scales = scales,
@ -67,7 +69,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
.beta = beta,
.schedule = schedule};
return scalar_type_dispatch(*btype, [&](auto BType) {
return scalar_type_dispatch(btype, [&](auto BType) {
return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
A.scalar_type(), "machete_gemm", [&] {
using ComputeType = equivalent_cutlass_type_t<scalar_t>;
@ -79,9 +81,9 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
#endif
}
torch::Tensor prepack_B(torch::Tensor const& B,
vllm::ScalarTypeTorchPtr const& btype) {
return scalar_type_dispatch(*btype, [&](auto BType) {
torch::Tensor prepack_B(torch::Tensor const& B, ScalarTypeId const btype_id) {
ScalarType const btype = ScalarType::from_id(btype_id);
return scalar_type_dispatch(btype, [&](auto BType) {
return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
});
}

View File

@ -89,7 +89,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
vllm::ScalarTypeId const b_q_type_id,
int64_t size_m, int64_t size_n,
int64_t size_k) {
TORCH_CHECK_NOT_IMPLEMENTED(
@ -1029,13 +1029,14 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
vllm::ScalarTypeId const b_q_type_id,
int64_t size_m, int64_t size_n,
int64_t size_k) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
// Verify num_bits
TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
"num_bits must be uint4b8 or uint8b128. Got = ", b_q_type->str());
int pack_factor = 32 / b_q_type->size_bits();
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
"num_bits must be uint4b8 or uint8b128. Got = ", b_q_type.str());
int pack_factor = 32 / b_q_type.size_bits();
// Verify M
TORCH_CHECK(size_m == a.size(0),
@ -1130,8 +1131,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
marlin_24::marlin_cuda_2_4(
a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(),
b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(),
b_q_type->size_bits(), groupsize, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par);
b_q_type.size_bits(), groupsize, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_m, sms, max_par);
return c;
}

View File

@ -140,13 +140,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Quantized GEMM for AWQ.
ops.def(
"awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
"Tensor _zeros, int split_k_iters) -> Tensor");
"Tensor _zeros, SymInt split_k_iters) -> Tensor");
ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
// Dequantization for AWQ.
ops.def(
"awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
"Tensor _zeros, int split_k_iters, int thx, int thy) -> Tensor");
"Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
// Note about marlin kernel 'workspace' arguments:
@ -166,32 +166,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Marlin (Dense) Optimized Quantized GEMM for GPTQ.
ops.def(
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
"Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
"Tensor");
// conditionally compiled so impl in source file
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
ops.def(
"gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
"Tensor b_scales, Tensor workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
"int size_m, int size_n, int size_k) -> Tensor");
"int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
// conditionally compiled so impl in source file
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
ops.def("machete_supported_schedules(int btype) -> str[]");
ops.def(
"machete_supported_schedules("
" __torch__.torch.classes._core_C.ScalarType btype"
") -> str[]");
ops.def(
"machete_gemm(Tensor A, Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype,"
"machete_gemm(Tensor A, Tensor B, int btype, "
" Tensor? scales, Tensor? zeros, int? group_size, "
" Tensor? C, float? alpha, float? beta, str? schedule)"
"-> Tensor");
ops.def(
"machete_prepack_B(Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype)"
"-> Tensor");
ops.def("machete_prepack_B(Tensor B, int btype) -> Tensor");
// conditionally compiled so impl registration is in source file
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
@ -201,8 +195,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
"int size_m, int size_n, int size_k, bool is_k_full, "
"int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool has_zp, bool use_fp32_reduce) -> Tensor");
// conditionally compiled so impl registration is in source file
@ -219,32 +213,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// conditionally compiled so impl registrations are in source file
// Dequantization for GGML.
ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
ops.def("ggml_dequantize(Tensor W, int type, SymInt m, SymInt n) -> Tensor");
ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
// mmvq kernel for GGML.
ops.def(
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, int row) "
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) "
"-> Tensor");
ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
// mmq kernel for GGML.
ops.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, int row) -> Tensor");
ops.def(
"ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops.def(
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int num_bits, int size_m, int size_n, "
"int size_k) -> Tensor");
"Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
"SymInt size_k) -> Tensor");
// conditionally compiled so impl registration is in source file
// marlin_qqq_gemm for QQQ.
ops.def(
"marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
"Tensor s_tok, Tensor s_ch, Tensor s_group, "
"Tensor! workspace, int size_m, int size_n, "
"int size_k) -> Tensor");
"Tensor! workspace, SymInt size_m, SymInt size_n, "
"SymInt size_k) -> Tensor");
// conditionally compiled so impl registration is in source file
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column

View File

@ -39,7 +39,6 @@ assert cwd != package_path, "should not import from the current directory"
files_to_copy = [
"vllm/_C.abi3.so",
"vllm/_core_C.abi3.so",
"vllm/_moe_C.abi3.so",
"vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so",
"vllm/vllm_flash_attn/flash_attn_interface.py",

View File

@ -290,10 +290,6 @@ def _build_custom_ops() -> bool:
return _is_cuda() or _is_hip() or _is_cpu()
def _build_core_ext() -> bool:
return not (_is_neuron() or _is_tpu() or _is_openvino() or _is_xpu())
def get_hipcc_rocm_version():
# Run the hipcc --version command
result = subprocess.run(['hipcc', '--version'],
@ -456,9 +452,6 @@ def get_requirements() -> List[str]:
ext_modules = []
if _build_core_ext():
ext_modules.append(CMakeExtension(name="vllm._core_C"))
if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._moe_C"))

View File

@ -69,11 +69,11 @@ def check_full_graph_support(model,
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
# Inductor doesn't support fp8/gptq_marlin_24 yet.
# Inductor doesn't support fp8 and the base meta llama uses too
# much memory.
quantization = model_kwargs.get("quantization")
if (quantization == "fp8" or quantization == "gptq_marlin"
or quantization == "gptq_marlin_24"
) and optimization_level >= CompilationLevel.INDUCTOR:
if ((quantization == "fp8" or model == "meta-llama/Meta-Llama-3-8B")
and optimization_level >= CompilationLevel.INDUCTOR):
return
prompts = [

View File

@ -80,7 +80,7 @@ def machete_quantize_and_pack(w: torch.Tensor,
w_q = w_q.t().contiguous().t() # convert to col major
w_q_machete = ops.machete_prepack_B(w_q, wtype)
opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype))
opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype.id))
return w_ref, w_q_machete, w_s, w_zp
@ -153,8 +153,9 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
schedule=schedule,
)
opcheck(torch.ops._C.machete_gemm,
(a, w_q_machete, wtype, w_s, maybe_convert_zeropoints(
opcheck(
torch.ops._C.machete_gemm,
(a, w_q_machete, wtype.id, w_s, maybe_convert_zeropoints(
w_zp, w_s), group_size, None, None, None, schedule))
# Relax atol as our reduction dim becomes larger (more rounding error)

View File

@ -225,7 +225,7 @@ def test_gptq_marlin_gemm(
opcheck(
torch.ops._C.gptq_marlin_gemm,
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
workspace.scratch, quant_type, a_input.shape[0], b_weight.shape[1],
workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1],
a_input.shape[1], is_k_full, False, use_fp32_reduce),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
@ -254,6 +254,16 @@ def test_gptq_marlin_gemm(
assert max_diff < 0.04
# TODO: find better way to test this?
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s, scratch, quant_type, size_m, size_n,
size_k):
return ops.gptq_marlin_24_gemm(a_input, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s, scratch, quant_type, size_m,
size_n, size_k)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@ -282,11 +292,11 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
opcheck(torch.ops._C.gptq_marlin_24_gemm,
(a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
workspace_24.scratch, quant_type, a_input.shape[0],
workspace_24.scratch, quant_type.id, a_input.shape[0],
b_weight.shape[1], a_input.shape[1]),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
output = ops.gptq_marlin_24_gemm(
output = marlin_24_gemm_tester(
a_input,
marlin_24_q_w_comp,
marlin_24_meta,

View File

@ -240,8 +240,8 @@ def test_fused_marlin_moe(
requires_grad=False)
opcheck(torch.ops._moe_C.marlin_gemm_moe,
(a, qweight1, sorted_token_ids, topk_weights, topk_ids,
scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m,
2 * n, k, True, e, topk, block_size_m, True, False))
scales1, zp, g_idx1, sort_indices1, workspace, quant_type.id,
m, 2 * n, k, True, e, topk, block_size_m, True, False))
@pytest.mark.skip("This test is here for the sake of debugging, "

View File

@ -32,5 +32,5 @@ def test_scalar_type_min_max(type_tuple):
max = torch.iinfo(torch_type).max
print(t, min, max, t.min(), t.max())
assert min == t.min()
assert max == t.max()
assert min == t.min(), f"min: {min} != {t.min()}"
assert max == t.max(), f"max: {max} != {t.max()}"

View File

@ -16,7 +16,6 @@ Typical output looks like this:
2.6 weighted s to build ...torch_bindings.cpp.o (31.5 s elapsed time)
3.2 weighted s to build ...torch_bindings.cpp.o (38.5 s elapsed time)
Longest build steps for .so (linking):
0.1 weighted s to build _core_C.abi3.so (0.7 s elapsed time)
0.1 weighted s to build _moe_C.abi3.so (1.0 s elapsed time)
0.5 weighted s to build ...flash_attn_c.abi3.so (1.1 s elapsed time)
6.2 weighted s to build _C.abi3.so (6.2 s elapsed time)

View File

@ -1,278 +0,0 @@
import importlib.util
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
core_C_available = importlib.util.find_spec('._core_C', 'vllm') is not None
# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
NONE = 0 # nans are not supported
IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
if TYPE_CHECKING or not core_C_available:
# On platforms were we cannot use/build the C++ core extension (i.e. namely
# neuron and tpu), we define the mock ScalarType class here that partially
# mimics the C++ ScalarType class.
#
# We also use this provide type signatures to the Python LSP for the methods
# in the C++ ScalarType class. So these type signatures should be kept
# in sync with csrc/core/scalar_type.hpp
from dataclasses import dataclass
@dataclass(frozen=True)
class ScalarType:
"""
ScalarType can represent a wide range of floating point and integer
types, in particular it can be used to represent sub-byte data types
(something that torch.dtype currently does not support). It is also
capable of representing types with a bias, i.e.:
`stored_value = value + bias`,
this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
of 8). The implementation for this class can be found in
csrc/core/scalar_type.hpp, these type signatures should be kept in sync
with that file.
"""
exponent: int
"""
Number of bits in the exponent if this is a floating point type
(zero if this an integer type)
"""
mantissa: int
"""
Number of bits in the mantissa if this is a floating point type,
or the number bits representing an integer excluding the sign bit if
this an integer type.
"""
bias: int
"""
bias used to encode the values in this scalar type
(value = stored_value - bias, default 0) for example if we store the
type as an unsigned integer with a bias of 128 then the value 0 will be
stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
"""
signed: bool
"If the type is signed (i.e. has a sign bit)"
_finite_values_only: bool = False
"""
Private: if NANs are supported, used `has_infs()` instead.
"""
nan_repr: int = NanRepr.IEEE_754.value
"""
How NaNs are represent in this scalar type, returns NanRepr value.
(not applicable for integer types)
"""
@property
def size_bits(self):
return self.exponent + self.mantissa + int(self.signed)
def min(self) -> Union[int, float]:
"""
Min representable value for this scalar type.
(accounting for bias if there is one)
"""
raise NotImplementedError
def max(self) -> Union[int, float]:
"""
Max representable value for this scalar type.
(accounting for bias if there is one)
"""
raise NotImplementedError
def is_signed(self) -> bool:
"""
If the type is signed (i.e. has a sign bit), same as `signed`
added for consistency with:
https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
"""
...
def is_floating_point(self) -> bool:
"If the type is a floating point type"
return self.exponent != 0
def is_integer(self) -> bool:
"If the type is an integer type"
return self.exponent == 0
def has_bias(self) -> bool:
"If the type has a non-zero bias"
return self.bias != 0
def has_infs(self) -> bool:
"If the type is floating point and supports infinity"
return not self._finite_values_only
def has_nans(self) -> bool:
return self.nan_repr != NanRepr.NONE.value
def is_ieee_754(self) -> bool:
"""
If the type is a floating point type that follows IEEE 754
conventions
"""
return self.nan_repr == NanRepr.IEEE_754.value and \
not self._finite_values_only
def __str__(self) -> str:
raise NotImplementedError
def __repr__(self) -> str:
raise NotImplementedError
# __len__ needs to be defined (and has to throw TypeError) for pytorch's
# opcheck to work.
def __len__(self) -> int:
raise TypeError
#
# Convenience Constructors
#
@classmethod
def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
"Create a signed integer scalar type (size_bits includes sign-bit)."
return cls(size_bits - 1, size_bits, bias if bias else 0, True)
@classmethod
def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
"""Create a unsigned integer scalar type."""
return cls(size_bits, size_bits, bias if bias else 0, False)
@classmethod
def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
"""
Create a standard floating point type
(i.e. follows IEEE 754 conventions).
"""
return cls(exponent, mantissa, 0, True)
@classmethod
def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
nan_repr: int) -> 'ScalarType':
"""
Create a non-standard floating point type
(i.e. does not follow IEEE 754 conventions).
"""
return cls(exponent, mantissa, 0, True, finite_values_only,
nan_repr)
elif core_C_available:
try:
import vllm._core_C # noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._core_C with %r", e)
ScalarType = torch.classes._core_C.ScalarType
if (hasattr(torch, "_library")
and hasattr(torch._library, "register_fake_class")):
# Needed for dynamo support of ScalarType.
@torch._library.register_fake_class("_core_C::ScalarType")
class FakeScalarType:
def __init__(self, scalar_type):
self.ScalarType = scalar_type
def bias_getter(self) -> int:
return self.ScalarType.bias
def exponent_getter(self) -> int:
return self.ScalarType.exponent
def mantissa_getter(self) -> int:
return self.ScalarType.mantissa
def signed_getter(self) -> bool:
return self.ScalarType.signed
def size_bits_getter(self) -> int:
return self.ScalarType.size_bits
@property
def size_bits(self) -> int:
return self.ScalarType.size_bits
def min(self) -> Union[int, float]:
return self.ScalarType.min()
def max(self) -> Union[int, float]:
return self.ScalarType.max()
def is_signed(self) -> bool:
return self.ScalarType.is_signed()
def is_floating_point(self) -> bool:
return self.ScalarType.is_floating_point()
def is_integer(self) -> bool:
return self.ScalarType.is_integer()
def has_bias(self) -> bool:
return self.ScalarType.has_bias()
def has_infs(self) -> bool:
return self.ScalarType.has_infs()
def has_nans(self) -> bool:
return self.ScalarType.has_nans()
def is_ieee_754(self) -> bool:
return self.ScalarType.is_ieee_754()
def __str__(self) -> str:
return self.ScalarType.__str__()
def __repr__(self) -> str:
return self.ScalarType.__repr__()
def __len__(self) -> int:
return self.ScalarType.__len__()
def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
return torch.classes._core_C.ScalarType.__obj_flatten__(
self.ScalarType)
@classmethod
def __obj_unflatten__(
cls, flat_type: Tuple[Tuple[str, Any],
...]) -> 'ScalarType':
return cls(
torch.classes._core_C.ScalarType.__obj_unflatten__(
flat_type))
@classmethod
def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
return ScalarType.int_(size_bits, bias)
@classmethod
def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
return ScalarType.uint(size_bits, bias)
@classmethod
def float_IEEE754(cls, exponent: int,
mantissa: int) -> 'ScalarType':
return ScalarType.float_IEEE754(exponent, mantissa)
@classmethod
def float_(cls, exponent: int, mantissa: int,
finite_values_only: bool,
nan_repr: int) -> 'ScalarType':
return ScalarType.float_(exponent, mantissa,
finite_values_only, nan_repr)

View File

@ -6,9 +6,9 @@ import torch
import torch.library
import vllm.envs as envs
from vllm._core_ext import ScalarType
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType
logger = init_logger(__name__)
@ -306,7 +306,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
workspace: torch.Tensor, b_q_type: ScalarType,
size_m: int, size_n: int, size_k: int) -> torch.Tensor:
return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
workspace, b_q_type, size_m,
workspace, b_q_type.id, size_m,
size_n, size_k)
@ -316,8 +316,9 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor,
b_q_type: ScalarType, size_m: int,
size_n: int, size_k: int) -> torch.Tensor:
b_q_type: ScalarType, size_m: torch.SymInt,
size_n: torch.SymInt,
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@register_fake("_C::gptq_marlin_gemm")
@ -329,17 +330,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
perm: torch.Tensor,
workspace: torch.Tensor,
b_q_type: ScalarType,
size_m: int,
size_n: int,
size_k: int,
size_m: torch.SymInt,
size_n: torch.SymInt,
size_k: torch.SymInt,
is_k_full: bool,
has_zp: bool = False,
use_fp32_reduce: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@register_fake("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
n: int) -> torch.Tensor:
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int,
m: torch.SymInt,
n: torch.SymInt) -> torch.Tensor:
return torch.empty((m, n), dtype=torch.float16, device=W.device)
@register_fake("_C::ggml_mul_mat_vec_a8")
@ -347,7 +349,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: int,
row: torch.SymInt,
) -> torch.Tensor:
return torch.empty((1, row), dtype=torch.float16, device=W.device)
@ -356,7 +358,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: int,
row: torch.SymInt,
) -> torch.Tensor:
batch = X.size(0)
return torch.empty((batch, row), dtype=torch.float16, device=W.device)
@ -365,8 +367,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
s_tok: torch.Tensor, s_ch: torch.Tensor,
s_group: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
size_m: torch.SymInt, size_n: torch.SymInt,
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n),
dtype=torch.float16,
device=a.device)
@ -374,16 +376,16 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@register_fake("_C::marlin_gemm")
def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
size_m: torch.SymInt, size_n: torch.SymInt,
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n),
dtype=torch.float16,
device=a.device)
@register_fake("_C::awq_dequantize")
def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
zeros: torch.Tensor, split_k_iters: torch.SymInt,
thx: int, thy: int) -> torch.Tensor:
in_c = qweight.size(0)
qout_c = qweight.size(1)
out_c = qout_c * 8
@ -394,7 +396,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@register_fake("_C::awq_gemm")
def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
qzeros: torch.Tensor, scales: torch.Tensor,
split_k_iters: int) -> torch.Tensor:
split_k_iters: torch.SymInt) -> torch.Tensor:
num_in_feats = input.size(0)
return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8),
dtype=input.dtype,
@ -429,8 +431,9 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@register_fake("_C::fp8_marlin_gemm")
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
num_bits: int, size_m: torch.SymInt,
size_n: torch.SymInt,
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
@register_fake("_C::machete_gemm")
@ -457,40 +460,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
return torch.empty_like(b_q_weight,
memory_format=torch.contiguous_format)
@register_fake("_C::causal_conv1d_fwd")
def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor],
conv_states: Optional[torch.Tensor],
cu_seq_len: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
silu_activation: bool, pad_slot_id: int):
return None
@register_fake("_C::causal_conv1d_update")
def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
weight: torch.Tensor,
bias_: Optional[torch.Tensor],
silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor],
pad_slot_id: int) -> None:
return None
@register_fake("_C::selective_scan_fwd")
def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
A: torch.Tensor, B: torch.Tensor,
C: torch.Tensor, D_: Optional[torch.Tensor],
z_: Optional[torch.Tensor],
delta_bias_: Optional[torch.Tensor],
delta_softplus: bool,
cu_seq_len: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
ssm_states: Optional[torch.Tensor],
pad_slot_id: int) -> None:
return None
# cutlass
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
@ -611,7 +580,7 @@ def gptq_marlin_gemm(a: torch.Tensor,
has_zp: bool = False,
use_fp32_reduce: bool = False) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
g_idx, perm, workspace, b_q_type,
g_idx, perm, workspace, b_q_type.id,
size_m, size_n, size_k, is_k_full,
has_zp, use_fp32_reduce)
@ -627,7 +596,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# machete
def machete_supported_schedules(b_type: ScalarType) -> List[str]:
return torch.ops._C.machete_supported_schedules(b_type)
return torch.ops._C.machete_supported_schedules(b_type.id)
def machete_gemm(
@ -642,13 +611,13 @@ def machete_gemm(
beta: Optional[float] = None,
schedule: Optional[str] = None,
) -> torch.Tensor:
return torch.ops._C.machete_gemm(a, b_q, b_type, b_scales, b_zeros,
return torch.ops._C.machete_gemm(a, b_q, b_type.id, b_scales, b_zeros,
b_group_size, c, alpha, beta, schedule)
def machete_prepack_B(b_q_weight: torch.Tensor,
b_type: ScalarType) -> torch.Tensor:
return torch.ops._C.machete_prepack_B(b_q_weight, b_type)
return torch.ops._C.machete_prepack_B(b_q_weight, b_type.id)
if hasattr(torch.ops._C, "permute_cols"):
@ -862,10 +831,10 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
topk_ids: torch.Tensor, b_scales: torch.Tensor,
b_zero_points: torch.Tensor, g_idx: torch.Tensor,
perm: torch.Tensor, workspace: torch.Tensor,
b_q_type: ScalarType, size_m: int, size_n: int,
size_k: int, is_k_full: bool, num_experts: int,
topk: int, moe_block_size: int,
replicate_input: bool,
b_q_type: ScalarType, size_m: torch.SymInt,
size_n: torch.SymInt, size_k: torch.SymInt,
is_k_full: bool, num_experts: int, topk: int,
moe_block_size: int, replicate_input: bool,
apply_weights: bool) -> torch.Tensor:
return torch.empty((size_m, topk, size_n),
dtype=a.dtype,

View File

@ -116,7 +116,7 @@ def single_marlin_moe(
intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K,
w_zeros, g_idx, sort_indices, workspace, scalar_type.id, M, N, K,
is_k_full, E, topk, block_size_m, True, False)
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
@ -272,7 +272,7 @@ def fused_marlin_moe(
g_idx1,
sort_indices1,
workspace,
scalar_type1,
scalar_type1.id,
M,
2 * N,
K,
@ -297,7 +297,7 @@ def fused_marlin_moe(
g_idx2,
sort_indices2,
workspace,
scalar_type2,
scalar_type2.id,
M,
K,
N,

View File

@ -1,4 +1,298 @@
from ._core_ext import NanRepr, ScalarType
import functools
import struct
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union
# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
NONE = 0 # nans are not supported
IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
# This ScalarType class is a parallel implementation of the C++ ScalarType
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
# in sync until the inductor fully supports custom C++ classes.
@dataclass(frozen=True)
class ScalarType:
"""
ScalarType can represent a wide range of floating point and integer
types, in particular it can be used to represent sub-byte data types
(something that torch.dtype currently does not support). It is also
capable of representing types with a bias, i.e.:
`stored_value = value + bias`,
this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
of 8). The implementation for this class can be found in
csrc/core/scalar_type.hpp, these type signatures should be kept in sync
with that file.
"""
exponent: int
"""
Number of bits in the exponent if this is a floating point type
(zero if this an integer type)
"""
mantissa: int
"""
Number of bits in the mantissa if this is a floating point type,
or the number bits representing an integer excluding the sign bit if
this an integer type.
"""
signed: bool
"If the type is signed (i.e. has a sign bit)"
bias: int
"""
bias used to encode the values in this scalar type
(value = stored_value - bias, default 0) for example if we store the
type as an unsigned integer with a bias of 128 then the value 0 will be
stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
"""
_finite_values_only: bool = False
"""
Private: if infs are supported, used `has_infs()` instead.
"""
nan_repr: NanRepr = NanRepr.IEEE_754
"""
How NaNs are represent in this scalar type, returns NanRepr value.
(not applicable for integer types)
"""
def _floating_point_max_int(self) -> int:
assert (
self.mantissa <= 52 and self.exponent <= 11
), f"Cannot represent max/min as a double for type {self.__str__()}"
max_mantissa = (1 << self.mantissa) - 1
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
max_mantissa = max_mantissa - 1
max_exponent = (1 << self.exponent) - 2
if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
or self.nan_repr == NanRepr.NONE):
assert (
self.exponent < 11
), f"Cannot represent max/min as a double for type {self.__str__()}"
max_exponent = max_exponent + 1
# adjust the exponent to match that of a double
# for now we assume the exponent bias is the standard 2^(e-1) -1, (where
# e is the exponent bits), there is some precedent for non-standard
# biases, example `float8_e4m3b11fnuz` here:
# https://github.com/jax-ml/ml_dtypes but to avoid premature over
# complication we are just assuming the standard exponent bias until
# there is a need to support non-standard biases
exponent_bias = (1 << (self.exponent - 1)) - 1
exponent_bias_double = (1 << 10) - 1 # double e = 11
max_exponent_double = (max_exponent - exponent_bias +
exponent_bias_double)
# shift the mantissa and exponent into the proper positions for an
# IEEE double and bitwise-or them together.
return (max_mantissa <<
(52 - self.mantissa)) | (max_exponent_double << 52)
def _floating_point_max(self) -> float:
double_raw = self._floating_point_max_int()
return struct.unpack('!d', struct.pack('!Q', double_raw))[0]
def _raw_max(self) -> Union[int, float]:
if self.is_floating_point():
return self._floating_point_max()
else:
assert (self.size_bits < 64 or self.size_bits == 64
and self.is_signed()), "Cannot represent max as an int"
return (1 << self.mantissa) - 1
def _raw_min(self) -> Union[int, float]:
if self.is_floating_point():
assert self.is_signed(
), "We currently assume all floating point types are signed"
sign_bit_double = 1 << 63
max_raw = self._floating_point_max_int()
min_raw = max_raw | sign_bit_double
return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
else:
assert (not self.is_signed() or
self.size_bits <= 64), "Cannot represent min as a int64_t"
if self.is_signed():
return -(1 << (self.size_bits - 1))
else:
return 0
@functools.cached_property
def id(self) -> int:
"""
Convert the ScalarType to an int which can be passed to pytorch custom
ops. This layout of the int must be kept in sync with the C++
ScalarType's from_id method.
"""
val = 0
offset = 0
def or_and_advance(member, bit_width):
nonlocal val
nonlocal offset
bit_mask = (1 << bit_width) - 1
val = val | (int(member) & bit_mask) << offset
offset = offset + bit_width
or_and_advance(self.exponent, 8)
or_and_advance(self.mantissa, 8)
or_and_advance(self.signed, 1)
or_and_advance(self.bias, 32)
or_and_advance(self._finite_values_only, 1)
or_and_advance(self.nan_repr.value, 8)
assert offset <= 64, \
f"ScalarType fields too big {offset} to fit into an int64"
return val
@property
def size_bits(self) -> int:
return self.exponent + self.mantissa + int(self.signed)
def min(self) -> Union[int, float]:
"""
Min representable value for this scalar type.
(accounting for bias if there is one)
"""
return self._raw_min() - self.bias
def max(self) -> Union[int, float]:
"""
Max representable value for this scalar type.
(accounting for bias if there is one)
"""
return self._raw_max() - self.bias
def is_signed(self) -> bool:
"""
If the type is signed (i.e. has a sign bit), same as `signed`
added for consistency with:
https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
"""
return self.signed
def is_floating_point(self) -> bool:
"If the type is a floating point type"
return self.exponent != 0
def is_integer(self) -> bool:
"If the type is an integer type"
return self.exponent == 0
def has_bias(self) -> bool:
"If the type has a non-zero bias"
return self.bias != 0
def has_infs(self) -> bool:
"If the type is floating point and supports infinity"
return not self._finite_values_only
def has_nans(self) -> bool:
return self.nan_repr != NanRepr.NONE.value
def is_ieee_754(self) -> bool:
"""
If the type is a floating point type that follows IEEE 754
conventions
"""
return self.nan_repr == NanRepr.IEEE_754.value and \
not self._finite_values_only
def __str__(self) -> str:
"""
naming generally follows: https://github.com/jax-ml/ml_dtypes
for floating point types (leading f) the scheme is:
`float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
flags:
- no-flags: means it follows IEEE 754 conventions
- f: means finite values only (no infinities)
- n: means nans are supported (non-standard encoding)
for integer types the scheme is:
`[u]int<size_bits>[b<bias>]`
- if bias is not present it means its zero
"""
if self.is_floating_point():
ret = "float" + str(self.size_bits) + "_e" + str(
self.exponent) + "m" + str(self.mantissa)
if not self.is_ieee_754():
if self._finite_values_only:
ret = ret + "f"
if self.nan_repr != NanRepr.NONE:
ret = ret + "n"
return ret
else:
ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
if self.has_bias():
ret = ret + "b" + str(self.bias)
return ret
def __repr__(self) -> str:
return "ScalarType." + self.__str__()
# __len__ needs to be defined (and has to throw TypeError) for pytorch's
# opcheck to work.
def __len__(self) -> int:
raise TypeError
#
# Convenience Constructors
#
@classmethod
def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
"Create a signed integer scalar type (size_bits includes sign-bit)."
ret = cls(0, size_bits - 1, True, bias if bias else 0)
ret.id # noqa B018: make sure the id is cached
return ret
@classmethod
def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
"""Create a unsigned integer scalar type."""
ret = cls(0, size_bits, False, bias if bias else 0)
ret.id # noqa B018: make sure the id is cached
return ret
@classmethod
def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
"""
Create a standard floating point type
(i.e. follows IEEE 754 conventions).
"""
assert (mantissa > 0 and exponent > 0)
ret = cls(exponent, mantissa, True, 0)
ret.id # noqa B018: make sure the id is cached
return ret
@classmethod
def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
nan_repr: NanRepr) -> 'ScalarType':
"""
Create a non-standard floating point type
(i.e. does not follow IEEE 754 conventions).
"""
assert (mantissa > 0 and exponent > 0)
assert (nan_repr != NanRepr.IEEE_754), (
"use `float_IEEE754` constructor for floating point types that "
"follow IEEE 754 conventions")
ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
ret.id # noqa B018: make sure the id is cached
return ret
# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
@ -17,14 +311,13 @@ class scalar_types:
uint4 = ScalarType.uint(4, None)
int8 = ScalarType.int_(8, None)
uint8 = ScalarType.uint(8, None)
float8_e4m3fn = ScalarType.float_(4, 3, True,
NanRepr.EXTD_RANGE_MAX_MIN.value)
float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
float8_e5m2 = ScalarType.float_IEEE754(5, 2)
float16_e8m7 = ScalarType.float_IEEE754(8, 7)
float16_e5m10 = ScalarType.float_IEEE754(5, 10)
# fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value)
float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
# "gptq" types
uint2b2 = ScalarType.uint(2, 2)