vllm/csrc/quantization/machete/machete_pytorch.cu

92 lines
3.0 KiB
Plaintext
Raw Normal View History

#include "machete_mm_launcher.cuh"
#include "machete_prepack_launcher.cuh"
#include "core/scalar_type.hpp"
namespace machete {
using namespace vllm;
//
// Utils (type dispatching)
//
template <typename Fn>
static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
if (type == vllm::kU4) {
return fn(cutlass::uint4b_t{});
} else if (type == vllm::kU8) {
return fn(cutlass::uint8_t{});
} else if (type == vllm::kU4B8) {
return fn(cutlass::vllm_uint4b8_t{});
} else if (type == vllm::kU8B128) {
return fn(cutlass::vllm_uint8b128_t{});
} else {
TORCH_CHECK(false, "Unsupported type ", type.str());
}
}
#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \
AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__)
#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, \
AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__))
//
// Interface
//
std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
return scalar_type_dispatch(*btype, [&](auto BType) {
return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
});
#else
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
ScalarTypeTorchPtr const& btype,
c10::optional<torch::Tensor> const& scales,
c10::optional<torch::Tensor> const& zeros,
c10::optional<int64_t> group_size,
c10::optional<torch::Tensor> const& C,
c10::optional<double> alpha, c10::optional<double> beta,
c10::optional<std::string> schedule) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
auto args = PyTorchArguments{.A = A,
.B = B,
.scales = scales,
.zeros = zeros,
.group_size = group_size,
.C = C,
.alpha = alpha,
.beta = beta,
.schedule = schedule};
return scalar_type_dispatch(*btype, [&](auto BType) {
return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
A.scalar_type(), "machete_gemm", [&] {
using ComputeType = equivalent_cutlass_type_t<scalar_t>;
return GemmDispatcher<ComputeType, decltype(BType)>::dispatch(args);
});
});
#else
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}
torch::Tensor prepack_B(torch::Tensor const& B,
ScalarTypeTorchPtr const& btype) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
return scalar_type_dispatch(*btype, [&](auto BType) {
return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
});
#else
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}
}; // namespace machete