2024-08-20 09:09:33 -04:00
|
|
|
#include "machete_mm_launcher.cuh"
|
|
|
|
#include "machete_prepack_launcher.cuh"
|
|
|
|
#include "core/scalar_type.hpp"
|
|
|
|
|
|
|
|
namespace machete {
|
|
|
|
|
|
|
|
using namespace vllm;
|
|
|
|
|
|
|
|
//
|
|
|
|
// Utils (type dispatching)
|
|
|
|
//
|
|
|
|
|
|
|
|
template <typename Fn>
|
|
|
|
static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
|
|
|
|
if (type == vllm::kU4) {
|
|
|
|
return fn(cutlass::uint4b_t{});
|
|
|
|
} else if (type == vllm::kU8) {
|
|
|
|
return fn(cutlass::uint8_t{});
|
|
|
|
} else if (type == vllm::kU4B8) {
|
|
|
|
return fn(cutlass::vllm_uint4b8_t{});
|
|
|
|
} else if (type == vllm::kU8B128) {
|
|
|
|
return fn(cutlass::vllm_uint8b128_t{});
|
|
|
|
} else {
|
|
|
|
TORCH_CHECK(false, "Unsupported type ", type.str());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \
|
|
|
|
AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__)
|
|
|
|
|
|
|
|
#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \
|
|
|
|
AT_DISPATCH_SWITCH(TYPE, NAME, \
|
|
|
|
AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__))
|
|
|
|
|
|
|
|
//
|
|
|
|
// Interface
|
|
|
|
//
|
|
|
|
|
|
|
|
std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
|
2024-08-22 08:28:52 -04:00
|
|
|
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
|
2024-08-20 09:09:33 -04:00
|
|
|
return scalar_type_dispatch(*btype, [&](auto BType) {
|
|
|
|
return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
|
|
|
|
});
|
2024-08-22 08:28:52 -04:00
|
|
|
#else
|
|
|
|
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
|
|
|
|
#endif
|
2024-08-20 09:09:33 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
|
|
|
|
ScalarTypeTorchPtr const& btype,
|
|
|
|
c10::optional<torch::Tensor> const& scales,
|
|
|
|
c10::optional<torch::Tensor> const& zeros,
|
|
|
|
c10::optional<int64_t> group_size,
|
|
|
|
c10::optional<torch::Tensor> const& C,
|
|
|
|
c10::optional<double> alpha, c10::optional<double> beta,
|
|
|
|
c10::optional<std::string> schedule) {
|
2024-08-22 08:28:52 -04:00
|
|
|
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
|
2024-08-20 09:09:33 -04:00
|
|
|
auto args = PyTorchArguments{.A = A,
|
|
|
|
.B = B,
|
|
|
|
.scales = scales,
|
|
|
|
.zeros = zeros,
|
|
|
|
.group_size = group_size,
|
|
|
|
.C = C,
|
|
|
|
.alpha = alpha,
|
|
|
|
.beta = beta,
|
|
|
|
.schedule = schedule};
|
|
|
|
|
|
|
|
return scalar_type_dispatch(*btype, [&](auto BType) {
|
|
|
|
return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
|
|
|
|
A.scalar_type(), "machete_gemm", [&] {
|
|
|
|
using ComputeType = equivalent_cutlass_type_t<scalar_t>;
|
|
|
|
return GemmDispatcher<ComputeType, decltype(BType)>::dispatch(args);
|
|
|
|
});
|
|
|
|
});
|
2024-08-22 08:28:52 -04:00
|
|
|
#else
|
|
|
|
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
|
|
|
|
#endif
|
2024-08-20 09:09:33 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
torch::Tensor prepack_B(torch::Tensor const& B,
|
|
|
|
ScalarTypeTorchPtr const& btype) {
|
2024-08-22 08:28:52 -04:00
|
|
|
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
|
2024-08-20 09:09:33 -04:00
|
|
|
return scalar_type_dispatch(*btype, [&](auto BType) {
|
|
|
|
return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
|
|
|
|
});
|
2024-08-22 08:28:52 -04:00
|
|
|
#else
|
|
|
|
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
|
|
|
|
#endif
|
2024-08-20 09:09:33 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
}; // namespace machete
|